Example #1
def train(config):
    # set pre-process
    prep_config = config["prep"]
    prep_dict = {}
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    # prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = datasets.ImageFolder(data_config['source']['list_path'], transform=prep_dict["source"])
    dset_loaders['source'] = getdataloader(dsets['source'], batchsize=train_bs, num_workers=4, drop_last=True, weightsampler=True)
    dsets["target"] = datasets.ImageFolder(data_config['target']['list_path'], transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs,
                                        shuffle=True, num_workers=4, drop_last=True)

    if prep_config["test_10crop"]:
        for i in range(10):
            dsets["test"] = [datasets.ImageFolder(data_config['test']['list_path'],
                                                  transform=prep_dict["test"][i]) for i in range(10)]
            dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs,
                                               shuffle=False, num_workers=4) for dset in dsets['test']]
    else:
        dsets["test"] = datasets.ImageFolder(data_config['test']['list_path'],
                                             transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs,
                                          shuffle=False, num_workers=4)

    class_num = config["network"]["params"]["class_num"]

    # set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.cuda()

    # set test_ad_net
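    # test_ad_net is a standalone domain discriminator trained only on
    # detached features (see test_domain_loss below); it is used to report
    # the A-distance without influencing the backbone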
    test_ad_net = network.AdversarialNetwork(base_network.output_num(), 1024, test_ad_net=True)
    test_ad_net = test_ad_net.cuda()

    # add additional network for some methods
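    # (the CDAN-style variants condition the discriminator on the classifier
    # output; the optional random layer keeps that conditioned input at a
    # fixed dimension)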
    if config['method'] == 'DANN':
        random_layer = None
        ad_net = network.AdversarialNetwork(base_network.output_num(), 1024)
    elif config['method'] == 'MADA':
        random_layer = None
        ad_net = network.AdversarialNetworkClassGroup(base_network.output_num(), 1024, class_num)
    elif config['method'] == 'proposed':
        if config['loss']['random']:
            random_layer = network.RandomLayer([base_network.output_num(), class_num], config['loss']['random_dim'])
            ad_net = network.AdversarialNetwork(config['loss']['random_dim'], 1024)
            ad_net_group = network.AdversarialNetworkGroup(config['loss']['random_dim'], 256, class_num, config['center_threshold'])
        else:
            random_layer = None
            ad_net = network.AdversarialNetwork(base_network.output_num(), 1024)
            ad_net_group = network.AdversarialNetworkGroup(base_network.output_num(), 1024, class_num, config['center_threshold'])
    elif config['method'] == 'base':
        pass
    else:
        if config["loss"]["random"]:
            random_layer = network.RandomLayer([base_network.output_num(), class_num], config["loss"]["random_dim"])
            ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
        else:
            random_layer = None
            ad_net = network.AdversarialNetwork(base_network.output_num() * class_num, 1024)
    if config["loss"]["random"] and config['method'] != 'base' and config['method'] != 'DANN' and config['method'] != 'MADA':
        random_layer.cuda()
    if config['method'] != 'base':
        ad_net = ad_net.cuda()
    if config['method'] == 'proposed':
        ad_net_group = ad_net_group.cuda()

    # set parameters
    if config['method'] == 'proposed':
        parameter_list = base_network.get_parameters() + test_ad_net.get_parameters() + ad_net.get_parameters() + ad_net_group.get_parameters()
    elif config['method'] == 'base':
        parameter_list = base_network.get_parameters() + test_ad_net.get_parameters()
    else:
        # MADA, DANN, and the CDAN variants share the same parameter list
        parameter_list = base_network.get_parameters() + test_ad_net.get_parameters() + ad_net.get_parameters()

    # set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    # parallel
    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        base_network = nn.DataParallel(base_network)
        test_ad_net = nn.DataParallel(test_ad_net)
        if config['method'] == 'DANN':
            ad_net = nn.DataParallel(ad_net)
        elif config['method'] == 'proposed':
            if config['loss']['random']:
                random_layer = nn.DataParallel(random_layer)
                ad_net = nn.DataParallel(ad_net)
                #wrapping ad_net_group in DataParallel raises an error, probably because
                #its output is not a tensor type, which DataParallel cannot handle yet
                #ad_net_group = nn.DataParallel(ad_net_group)
            else:
                ad_net = nn.DataParallel(ad_net)
                #ad_net_group = nn.DataParallel(ad_net_group)
        elif config['method'] == 'base':
            pass
        else:
            # CDAN+E
            if config["loss"]["random"]:
                random_layer = nn.DataParallel(random_layer)
                ad_net = nn.DataParallel(ad_net)
            # CDAN
            else:
                ad_net = nn.DataParallel(ad_net)

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == config["test_interval"] - 1:
            base_network.train(False)  # equivalent to base_network.eval()
            temp_acc = image_classification_test(dset_loaders, base_network, test_10crop=prep_config["test_10crop"])
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()
            print(log_str)
        # if i % config["snapshot_interval"] == 0:
        #     torch.save(nn.Sequential(base_network), osp.join(config["output_path"],
        #                                                      "iter_{:05d}_model.pth.tar".format(i)))

        loss_params = config["loss"]
        # train one iter
        base_network.train(True)
        if config['method'] != 'base':
            ad_net.train(True)
        if config['method'] == 'proposed':
            ad_net_group.train(True)
        # lr_scheduler
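        # (reapplied every iteration; see the inverse-decay sketch after Example #3)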
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        inputs_source, inputs_target, labels_source = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()
        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)
        if config['tsne']:
            # feature visualization by using T-SNE
            if i == int(0.98*config['num_iterations']):
                features_source_total = features_source.cpu().detach().numpy()
                features_target_total = features_target.cpu().detach().numpy()
            elif i > int(0.98*config['num_iterations']) and i < int(0.98*config['num_iterations'])+10:
                features_source_total = np.concatenate((features_source_total, features_source.cpu().detach().numpy()))
                features_target_total = np.concatenate((features_target_total, features_target.cpu().detach().numpy()))
            elif i == int(0.98*config['num_iterations'])+10:
                for index in range(config['tsne_num']):
                    features_embedded = TSNE(perplexity=10, n_iter=5000).fit_transform(np.concatenate((features_source_total, features_target_total)))
                    fig = plt.figure()
                    plt.scatter(features_embedded[:len(features_embedded)//2, 0], features_embedded[:len(features_embedded)//2, 1], c='r', s=1)
                    plt.scatter(features_embedded[len(features_embedded)//2:, 0], features_embedded[len(features_embedded)//2:, 1], c='b', s=1)
                    plt.savefig(osp.join(config["output_path"], config['method']+'-'+str(index)+'.png'))
                    plt.close()
            else:
                pass

        assert features_source.size(0) == features_target.size(0), 'Source and target batch sizes must match'
        assert outputs_source.size(0) == outputs_target.size(0), 'Source and target batch sizes must match'
        # source first, target second
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)

        # output the A_distance
        if i % config["test_interval"] == config["test_interval"] - 1:
            A_distance = cal_A_distance(test_ad_net, features)
            config['A_distance_file'].write(str(A_distance)+'\n')
            config['A_distance_file'].flush()

        softmax_out = nn.Softmax(dim=1)(outputs)
        if config['method'] == 'CDAN+E':
            entropy = loss.Entropy(softmax_out)
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy, network.calc_coeff(i), random_layer)
        elif config['method'] == 'CDAN':
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, None, None, random_layer)
        elif config['method'] == 'DANN':
            transfer_loss = loss.DANN(features, ad_net)
        elif config['method'] == 'MADA':
            transfer_loss = loss.MADA(features, softmax_out, ad_net)
        elif config['method'] == 'proposed':
            entropy = loss.Entropy(softmax_out)
            transfer_loss = loss.proposed([features, outputs], labels_source, ad_net, ad_net_group, entropy,
                                          network.calc_coeff(i), i, random_layer, config['loss']['trade_off23'])
        elif config['method'] == 'base':
            pass
        else:
            raise ValueError('Method cannot be recognized.')
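        # features are detached below, so this loss trains only test_ad_net
        # and leaves the feature extractor untouched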
        test_domain_loss = loss.DANN(features.clone().detach(), test_ad_net)
        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        if config['method'] == 'base':
            total_loss = classifier_loss + test_domain_loss
        else:
            total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss + test_domain_loss
        total_loss.backward()
        optimizer.step()
    # torch.save(best_model, osp.join(config["output_path"], "best_model.pth.tar"))
    return best_acc
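
Note: cal_A_distance is not defined in this example. The proxy A-distance is
usually derived from the domain classifier's error eps as d_A = 2 * (1 - 2 * eps).
A minimal sketch, assuming test_ad_net ends in a sigmoid (per-row
P(domain == source)) and that the first half of `features` is source:

def cal_A_distance(ad_net, features):
    with torch.no_grad():
        out = ad_net(features).view(-1)
        n_src = out.size(0) // 2
        domain_labels = torch.cat([torch.ones(n_src),
                                   torch.zeros(out.size(0) - n_src)]).to(out.device)
        # classification error of the domain discriminator
        err = ((out > 0.5).float() != domain_labels).float().mean().item()
    return 2 * (1 - 2 * err)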
Example #2
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=True)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=True)

    if prep_config["test_10crop"]:
        # one dataset/loader per crop
        dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                   transform=prep_dict["test"][i]) for i in range(10)]
        dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                                           shuffle=False, num_workers=0) for dset in dsets['test']]
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                  transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                          shuffle=False, num_workers=0)

    # class_num = config["network"]["params"]["class_num"]
    n_class = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network_stu = net_config["name"](**net_config["params"]).cuda()
    base_network_tea = net_config["name"](**net_config["params"]).cuda()

    ## add additional network for some methods
    if config["loss"]["random"]:
        random_layer = network.RandomLayer([base_network_stu.output_num(), n_class], config["loss"]["random_dim"])
        ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
    else:
        random_layer = None

        if config['method'] == 'DANN':
            ad_net = network.AdversarialNetwork(base_network_stu.output_num(), 1024)  # DANN
        else:
            ad_net = network.AdversarialNetwork(base_network_stu.output_num() * n_class, 1024)

    if config["loss"]["random"]:
        random_layer.cuda()
    ad_net = ad_net.cuda()
    ad_net2 = network.AdversarialNetwork(n_class, n_class*4)  # NOTE: created but never used below
    ad_net2.cuda()

    parameter_list = base_network_stu.get_parameters() + ad_net.get_parameters()

    teacher_params = list(base_network_tea.parameters())
    for param in teacher_params:
        param.requires_grad = False

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                                         **(optimizer_config["optim_params"]))

    teacher_optimizer = EMAWeightOptimizer(base_network_tea, base_network_stu, alpha=0.99)
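    # the teacher receives no gradient updates; EMAWeightOptimizer.step()
    # copies an exponential moving average of the student weights into it
    # (mean teacher; a sketch follows this example)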

    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    best_acc = 0.0

    output1(log_name)  # output1 / log_name are defined outside this snippet

    loss1, loss2, loss3, loss4, loss5, loss6 = 0, 0, 0, 0, 0, 0

    output1('    =======    DA TRAINING    =======    ')

    best1 = 0
    f_t_result = []
    max_iter = config["num_iterations"]
    for i in range(max_iter+1):
        if i % config["test_interval"] == config["test_interval"] - 1 and i > 1500:
            base_network_tea.train(False)
            base_network_stu.train(False)

            # print("test")
            if 'MT' in config['method']:
                temp_acc = image_classification_test(dset_loaders,
                                                     base_network_tea, test_10crop=prep_config["test_10crop"])
                if temp_acc > best_acc:
                    best_acc = temp_acc

                log_str = "iter: {:05d}, tea_precision: {:.5f}".format(i, temp_acc)
                output1(log_str)
                #
                # if i > 20001 and best_acc < 0.69:
                #     break
                #
                # if temp_acc < 0.67:
                #     break
                #
                # if i > 30001 and best_acc < 0.71:
                #     break
                    # torch.save(base_network_tea, osp.join(path, "_model.pth.tar"))

                # temp_acc = image_classification_test(dset_loaders,
                #                                      base_network_stu, test_10crop=prep_config["test_10crop"])
                # NOTE: best_acc was just updated above, so this check never
                # fires; it originally guarded the commented-out student test
                if temp_acc > best_acc:
                    best_acc = temp_acc
                    torch.save(base_network_stu, osp.join(path, "_model.pth.tar"))
                # log_str = "iter: {:05d}, stu_precision: {:.5f}".format(i, temp_acc)
                # output1(log_str)
            else:
                temp_acc = image_classification_test(dset_loaders,
                                                     base_network_stu, test_10crop=prep_config["test_10crop"])
                if temp_acc > best_acc:
                    best_acc = temp_acc
                    torch.save(base_network_stu, osp.join(path,"_model.pth.tar"))
                log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
                output1(log_str)

        loss_params = config["loss"]
        ## train one iter
        base_network_stu.train(True)
        base_network_tea.train(True)

        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()

        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])

        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)

        inputs_source, inputs_target, labels_source = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()
        features_source, outputs_source = base_network_stu(inputs_source)

        features_target_stu, outputs_target_stu = base_network_stu(inputs_target)
        features_target_tea, outputs_target_tea = base_network_tea(inputs_target)

        softmax_out_source = nn.Softmax(dim=1)(outputs_source)
        softmax_out_target_stu = nn.Softmax(dim=1)(outputs_target_stu)
        softmax_out_target_tea = nn.Softmax(dim=1)(outputs_target_tea)

        features = torch.cat((features_source, features_target_stu), dim=0)

        if 'MT' in config['method']:
            softmax_out = torch.cat((softmax_out_source, softmax_out_target_tea), dim=0)

        else:
            softmax_out = torch.cat((softmax_out_source, softmax_out_target_stu), dim=0)

        vat_loss = VAT(base_network_stu).cuda()

        n, d = features_source.shape
        decay = cal_decay(start=1, end=0.6, i=i)
        # image number in each class
        s_labels = labels_source
        t_max, t_labels = torch.max(softmax_out_target_tea, 1)
        t_max, t_labels = t_max.cuda(), t_labels.cuda()
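        # t_labels are pseudo-labels for the target batch: the teacher's
        # argmax class per sample, with t_max as its confidence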
        if config['method'] == 'DANN+dis' or config['method'] == 'CDRL':
            pass

        elif config['method'] == 'RESNET':
            classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
            total_loss = classifier_loss

        elif config['method'] == 'CDAN+E':
            entropy = losssntg.Entropy(softmax_out)
            ad_loss = losssntg.CDANori([features, softmax_out], ad_net, entropy, network.calc_coeff(i),
                                          random_layer)

            transfer_loss = loss_params["trade_off"] * ad_loss
            classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
            total_loss = transfer_loss + classifier_loss


        elif config['method'] == 'CDAN':
            ad_loss = losssntg.CDANori([features, softmax_out], ad_net, None, None, random_layer)

            transfer_loss = loss_params["trade_off"] * ad_loss
            classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
            total_loss = transfer_loss + classifier_loss


        elif config['method'] == 'DANN':

            ad_loss = losssntg.DANN(features, ad_net)
            transfer_loss = loss_params["trade_off"] * ad_loss
            classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
            total_loss = transfer_loss + classifier_loss


        elif config['method'] == 'CDAN+MT':
            th = config['th']

            ad_loss = losssntg.CDANori([features, softmax_out], ad_net, None, None, random_layer)
            # student/teacher consistency loss on the target predictions
            unsup_loss = compute_aug_loss2(softmax_out_target_stu, softmax_out_target_tea, n_class, confidence_thresh=th)

            transfer_loss = loss_params["trade_off"] * ad_loss \
                            + 0.01 * unsup_loss
            classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
            total_loss = transfer_loss + classifier_loss

        elif config['method'] == 'CDAN+MT+VAT':
            cent = ConditionalEntropyLoss().cuda()
            ad_loss = losssntg.CDANori([features, softmax_out], ad_net, None, None, random_layer)
            unsup_loss = compute_aug_loss(softmax_out_target_stu, softmax_out_target_tea, n_class)
            loss_trg_cent = 1e-2 * cent(outputs_target_stu)
            loss_trg_vat = 1e-2 * vat_loss(inputs_target, outputs_target_stu)
            transfer_loss = loss_params["trade_off"] * ad_loss \
                            + 0.001*(unsup_loss + loss_trg_cent + loss_trg_vat)
            classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
            total_loss = transfer_loss + classifier_loss

        elif config['method'] == 'CDAN+MT+cent+VAT+temp':
            th = 0.7
            ad_loss = losssntg.CDANori([features, softmax_out], ad_net, None, None, random_layer)
            unsup_loss = compute_aug_loss(softmax_out_target_stu, softmax_out_target_tea, n_class,confidence_thresh=th)
            cent = ConditionalEntropyLoss().cuda()
            # loss_src_vat = vat_loss(inputs_source, outputs_source)
            loss_trg_cent = 1e-2 * cent(outputs_target_stu)
            loss_trg_vat = 1e-2 * vat_loss(inputs_target, outputs_target_stu)
            transfer_loss = loss_params["trade_off"] * ad_loss \
                            + unsup_loss + loss_trg_cent + loss_trg_vat
            # temperature
            classifier_loss = nn.NLLLoss()(nn.LogSoftmax(1)(outputs_source / 1.05), labels_source)
            total_loss = transfer_loss + classifier_loss

        elif config['method'] == 'CDAN+MT+cent+VAT+weightCross+T':
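            # once per pass over the target loader, re-derive class weights
            # from the previous pass's pseudo-label counts
            # (weight_c = total - count_c, down-weighting frequent classes)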

            if i % len_train_target == 0:

                if i != 0:
                    # print(cnt)
                    cnt = torch.tensor(cnt).float()
                    weight = cnt.sum() - cnt
                    weight = weight.cuda()
                else:
                    weight = torch.ones(n_class).cuda()

                cnt = [0] * n_class


            for j in t_labels:
                cnt[j.item()] += 1


            a = config['a']
            b = config['b']
            th = config['th']
            temp = config['temp']

            ad_loss = losssntg.CDANori([features, softmax_out], ad_net, None, None, random_layer)
            unsup_loss = compute_aug_loss2(softmax_out_target_stu, softmax_out_target_tea, n_class,confidence_thresh=th)
            # cbloss = compute_cbloss(softmax_out_target_stu, n_class, cls_balance=0.05)
            # unsup_loss = compute_aug_loss(softmax_out_target_stu, softmax_out_target_tea, n_class,confidence_thresh=th)

            cent = ConditionalEntropyLoss().cuda()

            # loss_src_vat = vat_loss(inputs_source, outputs_source)
            loss_trg_cent = 1e-2 * cent(outputs_target_stu)
            loss_trg_vat = 1e-2 * vat_loss(inputs_target, outputs_target_stu)
            classifier_loss = nn.CrossEntropyLoss(weight=weight)(outputs_source/temp, labels_source)
            # classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
            transfer_loss = loss_params["trade_off"] * ad_loss \
                + a*unsup_loss + b*(loss_trg_vat+loss_trg_cent)

            total_loss = transfer_loss + classifier_loss

        elif config['method'] == 'CDAN+MT+E+VAT+weightCross+T':
            entropy = losssntg.Entropy(softmax_out)
            if i % len_train_target == 0:

                if i != 0:
                    # print(cnt)
                    cnt = torch.tensor(cnt).float()
                    weight = cnt.sum() - cnt
                    weight = weight.cuda()
                else:
                    weight = torch.ones(n_class).cuda()

                cnt = [0] * n_class

            for j in t_labels:
                cnt[j.item()] += 1

            a = config['a']
            b = config['b']
            th = config['th']
            temp = config['temp']

            ad_loss = losssntg.CDANori([features, softmax_out], ad_net, entropy, network.calc_coeff(i),
                                       random_layer)
            # unsup_loss = compute_aug_loss2(softmax_out_target_stu, softmax_out_target_tea, n_class, confidence_thresh=th)
            # cbloss = compute_cbloss(softmax_out_target_stu, n_class, cls_balance=0.05)
            unsup_loss = compute_aug_loss(softmax_out_target_stu, softmax_out_target_tea, n_class, confidence_thresh=th)

            cent = ConditionalEntropyLoss().cuda()

            # loss_src_vat = vat_loss(inputs_source, outputs_source)
            loss_trg_cent = 1e-2 * cent(outputs_target_stu)
            loss_trg_vat = 1e-2 * vat_loss(inputs_target, outputs_target_stu)
            classifier_loss = nn.CrossEntropyLoss(weight=weight)(outputs_source / temp, labels_source)
            # classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
            transfer_loss = loss_params["trade_off"] * ad_loss \
                            + a * unsup_loss + b * (loss_trg_vat + loss_trg_cent)

            total_loss = transfer_loss + classifier_loss

        # adout
        # ad_out1 = ad_net(features_source)
        # w = 1-ad_out1
        # c = w * nn.CrossEntropyLoss(reduction='none')(outputs_source, labels_source)
        # classifier_loss = c.mean()
        # total_loss = transfer_loss + classifier_loss
        total_loss.backward()
        optimizer.step()
        teacher_optimizer.step()

        loss1 += ad_loss.item()  # NOTE: ad_loss/total_loss are undefined on some branches above (e.g. 'RESNET')
        loss2 += classifier_loss.item()
        # loss3 += unsup_loss.item()
        # loss4 += loss_trg_cent.item()
        # loss5 += loss_trg_vat.item()
        # loss6 += cbloss.item()
        # dis_sloss_l += sloss_l.item()

        if i % 50 == 0 and i != 0:
            output1('iter:{:d}, ad_loss_D:{:.2f}, closs:{:.2f}, unsup_loss:{:.2f}, loss_trg_cent:{:.2f}, loss_trg_vatcd:{:.2f}, cbloss:{:.2f}'
                    .format(i, loss1, loss2, loss3, loss4, loss5,loss6))
            loss1, loss2, loss3, loss4, loss5, loss6 = 0, 0, 0, 0, 0, 0

    # torch.save(best_model, osp.join(path, "best_model.pth.tar"))
    return best_acc
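
Note: EMAWeightOptimizer is not shown in this snippet. In mean-teacher
training it typically mirrors the student into the teacher as an exponential
moving average after every optimizer step; no gradients flow to the teacher
(its parameters have requires_grad=False above). A minimal sketch, assuming
that interface:

class EMAWeightOptimizer:
    def __init__(self, teacher, student, alpha=0.99):
        self.teacher_params = list(teacher.parameters())
        self.student_params = list(student.parameters())
        self.alpha = alpha
        # start the teacher from the student's current weights
        for t, s in zip(self.teacher_params, self.student_params):
            t.data.copy_(s.data)

    def step(self):
        # teacher <- alpha * teacher + (1 - alpha) * student
        for t, s in zip(self.teacher_params, self.student_params):
            t.data.mul_(self.alpha).add_(s.data, alpha=1 - self.alpha)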
Example #3
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
                                        shuffle=True, num_workers=4, drop_last=True)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
                                        shuffle=True, num_workers=4, drop_last=True)

    if prep_config["test_10crop"]:
        # one dataset/loader per crop
        dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                   transform=prep_dict["test"][i]) for i in range(10)]
        dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                                           shuffle=False, num_workers=4) for dset in dsets['test']]
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                  transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                          shuffle=False, num_workers=4)

    class_num = config["network"]["params"]["class_num"]
    crit = LabelSmoothingLoss(smoothing=0.1, classes=class_num)  # label smoothing

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.cuda()  # load the base network

    ## add additional network for some methods
    if config["loss"]["random"]:
        random_layer = network.RandomLayer([base_network.output_num(), class_num], config["loss"]["random_dim"])
        ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(base_network.output_num() * class_num, 1024)  # adversarial network
    if config["loss"]["random"]:
        random_layer.cuda()
    ad_net = ad_net.cuda()
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                                         **(optimizer_config["optim_params"]))

    # center loss, with its own SGD optimizer so the class centers are
    # updated separately from the network parameters
    criterion_centor = CenterLoss(num_classes=class_num, feat_dim=256, use_gpu=True)
    optimizer_centerloss = torch.optim.SGD(criterion_centor.parameters(), lr=config['lr'])

    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
        base_network = nn.DataParallel(base_network, device_ids=[int(i) for i in gpus])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    start_time = time.time()
    for i in range(config["num_iterations"]):

        if i % config["test_interval"] == config["test_interval"] - 1:
            # run the evaluation here
            base_network.train(False)
            temp_acc = image_classification_test(dset_loaders, \
                                                 base_network, test_10crop=prep_config["test_10crop"])
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()
            print(log_str)
            end_time = time.time()
            print('iter {} cost time {:.4f} sec.'.format(i, end_time - start_time))  # print elapsed time
            start_time = time.time()

        if i % config["snapshot_interval"] == 0:
            torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \
                                                             "iter_{:05d}_model.pth.tar".format(i)))

        loss_params = config["loss"]

        ## train one iter
        base_network.train(True)  # training mode
        ad_net.train(True)

        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        # optimizer_centerloss=lr_scheduler(optimizer_centerloss, i, **schedule_param)

        optimizer.zero_grad()
        optimizer_centerloss.zero_grad()

        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        inputs_source, inputs_target, labels_source = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()
        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)
        if config['method'] == 'CDAN+E':
            entropy = loss.Entropy(softmax_out)
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy, network.calc_coeff(i), random_layer)
        elif config['method'] == 'CDAN':
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, None, None, random_layer)
        elif config['method'] == 'DANN':
            transfer_loss = loss.DANN(features, ad_net)
        else:
            raise ValueError('Method cannot be recognized.')
        # classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)  # plain source classification loss
        classifier_loss = crit(outputs_source, labels_source)  # source classification loss with label smoothing

        # compute the center loss on the source features
        loss_centor = criterion_centor(features_source, labels_source)

        total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss + config['centor_w'] * loss_centor
        if i % config["test_interval"] == config["test_interval"] - 1:
            print('total loss: {:.4f}, transfer loss: {:.4f}, classifier loss: {:.4f}, centor loss: {:.4f}'.format(
                total_loss.item(), transfer_loss.item(), classifier_loss.item(), config['centor_w'] * loss_centor.item()
            ))
        total_loss.backward()
        optimizer.step()
        # by doing so, weight_cent would not impact on the learning of centers
        for param in criterion_centor.parameters():
            param.grad.data *= (1. / config['centor_w'])
        optimizer_centerloss.step()

    torch.save(best_model, osp.join(config["output_path"], "best_model.pth.tar"))
    return best_acc
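
Note: lr_schedule.schedule_dict is not shown in these examples. The "inv"
schedule referenced in the config comment of the next example is typically
the inverse-decay rule lr_i = lr * (1 + gamma * i) ** (-power). A sketch,
assuming each parameter group carries lr_mult/decay_mult factors set by
get_parameters():

def inv_lr_scheduler(optimizer, iter_num, gamma, power, lr=0.001, weight_decay=0.0005):
    # decay the base lr, then scale it per parameter group
    lr = lr * (1 + gamma * iter_num) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr * param_group['lr_mult']
        param_group['weight_decay'] = weight_decay * param_group['decay_mult']
    return optimizer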
Example #4
def train(config):
    ####################################################
    # Tensorboard setting
    ####################################################
    #tensor_writer = SummaryWriter(config["tensorboard_path"])

    ####################################################
    # Data setting
    ####################################################

    prep_dict = {}  # data pre-processing transforms
    prep_dict["source"] = prep.image_train(**config['prep']['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    prep_dict["test"] = prep.image_test(**config['prep']['params'])

    dsets = {}
    dsets["source"] = datasets.ImageFolder(config['s_dset_path'],
                                           transform=prep_dict["source"])
    dsets["target"] = datasets.ImageFolder(config['t_dset_path'],
                                           transform=prep_dict['target'])
    dsets['test'] = datasets.ImageFolder(config['t_dset_path'],
                                         transform=prep_dict['test'])

    data_config = config["data"]
    train_source_bs = data_config["source"][
        "batch_size"]  #원본은 source와 target 모두 source train bs로 설정되었는데 이를 수정함
    train_target_bs = data_config['target']['batch_size']
    test_bs = data_config["test"]["batch_size"]

    dset_loaders = {}
    dset_loaders["source"] = DataLoader(
        dsets["source"],
        batch_size=train_source_bs,
        shuffle=True,
        num_workers=4,
        drop_last=True
    )  # original used drop_last=True; this keeps source and target producing equal-sized batches throughout
    dset_loaders["target"] = DataLoader(dsets["target"],
                                        batch_size=train_target_bs,
                                        shuffle=True,
                                        num_workers=4,
                                        drop_last=True)
    dset_loaders['test'] = DataLoader(dsets['test'],
                                      batch_size=test_bs,
                                      shuffle=False,
                                      num_workers=4,
                                      drop_last=False)

    ####################################################
    # Network Setting
    ####################################################

    class_num = config["network"]['params']['class_num']

    net_config = config["network"]
    """
        config['network'] = {'name': network.ResNetFc,
                         'params': {'resnet_name': args.net,
                                    'use_bottleneck': True,
                                    'bottleneck_dim': 256,
                                    'new_cls': True,
                                    'class_num': args.class_num,
                                    'type' : args.type}
                         }
    """

    base_network = net_config["name"](**net_config["params"])
    # calls the ResNetFc class defined in network.py
    base_network = base_network.cuda()  # ResNetFc(Resnet, True, 256, True, 12)

    if config["loss"]["random"]:
        random_layer = network.RandomLayer(
            [base_network.output_num(), class_num],
            config["loss"]["random_dim"])
        random_layer.cuda()
        ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(
            base_network.output_num() * class_num, 1024)  # why times class_num? see the conditioning sketch after this example

    ad_net = ad_net.cuda()

    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    ####################################################
    # Env Setting
    ####################################################

    #gpus = config['gpu'].split(',')
    #if len(gpus) > 1 :
    #ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
    #base_network = nn.DataParallel(base_network, device_ids=[int(i) for i in gpus])

    ####################################################
    # Optimizer Setting
    ####################################################

    optimizer_config = config['optimizer']
    optimizer = optimizer_config["type"](parameter_list,
                                         **(optimizer_config["optim_params"]))
    # optim.SGD

    #config['optimizer'] = {'type': optim.SGD,
    #'optim_params': {'lr': args.lr,
    #'momentum': 0.9,
    #'weight_decay': 0.0005,
    #'nestrov': True},
    #'lr_type': "inv",
    #'lr_param': {"lr": args.lr,
    #'gamma': 0.001, # shouldn't this be 0.01?
    #'power': 0.75
    #}

    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group['lr'])

    schedule_param = optimizer_config['lr_param']

    lr_scheduler = lr_schedule.schedule_dict[
        optimizer_config["lr_type"]]  # return optimizer

    ####################################################
    # Train
    ####################################################

    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])

    transfer_loss_value = 0.0
    classifier_loss_value = 0.0
    total_loss_value = 0.0

    best_acc = 0.0

    batch_size = config["data"]["source"]["batch_size"]

    for i in range(
            config["num_iterations"]):  # num_iterations수의 batch가 학습에 사용됨
        sys.stdout.write("Iteration : {} \r".format(i))
        sys.stdout.flush()

        loss_params = config["loss"]

        base_network.train(True)
        ad_net.train(True)

        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()

        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])

        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)

        inputs_source, labels_source = inputs_source.cuda(
        ), labels_source.cuda()
        inputs_target = inputs_target.cuda()

        inputs = torch.cat((inputs_source, inputs_target), dim=0)

        features, outputs, tau, cur_mean_source, cur_mean_target, output_mean_source, output_mean_target = base_network(
            inputs)

        softmax_out = nn.Softmax(dim=1)(outputs)

        outputs_source = outputs[:batch_size]
        outputs_target = outputs[batch_size:]

        if config['method'] == 'CDAN+E' or config['method'] == 'CDAN_TransNorm':
            entropy = loss.Entropy(softmax_out)
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy,
                                      network.calc_coeff(i), random_layer)
        elif config['method'] == 'CDAN':
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, None,
                                      None, random_layer)
        elif config['method'] == 'DANN':
            pass  # TODO: tidy this up later
        else:
            raise ValueError('Method cannot be recognized')

        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss

        total_loss.backward()
        optimizer.step()

        #tensor_writer.add_scalar('total_loss', total_loss, i)
        #tensor_writer.add_scalar('classifier_loss', classifier_loss, i)
        #tensor_writer.add_scalar('transfer_loss', transfer_loss, i)

        ####################################################
        # Test
        ####################################################
        if i % config["test_interval"] == config["test_interval"] - 1:
            # every test_interval iterations
            base_network.train(False)
            temp_acc = image_classification_test(dset_loaders, base_network)
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model
                ACC = round(best_acc * 100, 2)  # round after scaling to a percentage
                torch.save(
                    best_model,
                    os.path.join(config["output_path"],
                                 "iter_{}_model.pth.tar".format(ACC)))
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()
            print(log_str)
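
Note: this also answers the question flagged above about the discriminator's
input size. Without a random layer, CDAN-style conditioning feeds the
discriminator the flattened outer product of each feature vector and its
softmax prediction, so the input dimension is output_num() * class_num. A
simplified sketch of just that conditioning step (the detaching and entropy
weighting of the full loss are omitted):

def cdan_discriminator_input(features, softmax_out, random_layer=None):
    if random_layer is None:
        # per-sample outer product: (B, C, 1) x (B, 1, D) -> (B, C, D)
        op = torch.bmm(softmax_out.unsqueeze(2), features.unsqueeze(1))
        return op.view(-1, softmax_out.size(1) * features.size(1))
    # a random multilinear map keeps the input at a fixed dimension
    rand_out = random_layer.forward([features, softmax_out])
    return rand_out.view(-1, rand_out.size(1))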
Example #5
def main():
    parser = argparse.ArgumentParser(description='CDAN SVHN MNIST')
    parser.add_argument('--method',
                        type=str,
                        default='CDAN-E',
                        choices=['CDAN', 'CDAN-E', 'DANN'])
    parser.add_argument('--task', default='USPS2MNIST', help='task to perform')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.03, metavar='LR')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--gpu_id',
                        default='0',
                        type=str,
                        help='cuda device id')
    parser.add_argument('--seed',
                        type=int,
                        default=40,
                        metavar='S',
                        help='random seed (default: 40)')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        help='how many batches to wait before logging training status')
    parser.add_argument('--random',
                        type=bool,
                        default=False,
                        help='whether to use random')
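    # NOTE: argparse's type=bool turns any non-empty string (including
    # "False") into True; action='store_true' is the usual fix. The same
    # pattern appears in the later examples.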
    parser.add_argument("--mdd_weight", type=float, default=0)
    parser.add_argument("--entropic_weight", type=float, default=0)
    parser.add_argument("--weight", type=float, default=1)
    parser.add_argument("--left_weight", type=float, default=1)
    parser.add_argument("--right_weight", type=float, default=1)
    parser.add_argument('--use_seed', type=int, default=1)
    args = parser.parse_args()
    if args.use_seed:
        import random
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        random.seed(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
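    # cudnn: benchmark off + deterministic on trades speed for run-to-run reproducibility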

    import os.path as osp
    import time
    config = {}

    config["use_seed"] = args.use_seed
    config['seed'] = args.seed
    config["output_path"] = "snapshot/s2m"
    config["mdd_weight"] = args.mdd_weight
    config["entropic_weight"] = args.entropic_weight
    config["weight"] = args.weight
    config["left_weight"] = args.left_weight
    config["right_weight"] = args.right_weight
    os.makedirs(config["output_path"], exist_ok=True)
    config["out_file"] = open(
        osp.join(
            config["output_path"],
            "log_svhn_to_mnist_{}______{}.txt".format(str(int(time.time())),
                                                      str(args.seed))), "w")

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    source_list = 'data/svhn2mnist/svhn_balanced.txt'
    target_list = 'data/svhn2mnist/mnist_train.txt'
    test_list = 'data/svhn2mnist/mnist_test.txt'

    train_loader = torch.utils.data.DataLoader(ImageList(
        open(source_list).readlines(),
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, ), (0.5, ))]),
        mode='RGB'),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=0)
    train_loader1 = torch.utils.data.DataLoader(ImageList(
        open(target_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='RGB'),
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=0)
    test_loader = torch.utils.data.DataLoader(ImageList(
        open(test_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='RGB'),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              num_workers=0)

    model = network.DTN()
    model = model.cuda()
    class_num = 10

    if args.random:
        random_layer = network.RandomLayer([model.output_num(), class_num],
                                           500)
        ad_net = network.AdversarialNetwork(500, 500)
        random_layer.cuda()
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(model.output_num() * class_num,
                                            500)
    ad_net = ad_net.cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          weight_decay=0.0005,
                          momentum=0.9)
    optimizer_ad = optim.SGD(ad_net.parameters(),
                             lr=args.lr,
                             weight_decay=0.0005,
                             momentum=0.9)

    config["out_file"].write(str(config))
    config["out_file"].flush()
    best_model = model
    best_acc = 0

    for epoch in range(1, args.epochs + 1):
        if epoch % 3 == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.3

        train(args, config, model, ad_net, random_layer, train_loader,
              train_loader1, optimizer, optimizer_ad, epoch)
        acc = test(epoch, config, model, test_loader)
        if (acc > best_acc):
            best_model = model
            best_acc = acc

    torch.save(
        best_model,
        osp.join("snapshot/s2m_model",
                 "s2m_{}_{}".format(str(best_acc), str(args.mdd_weight))))
Example #6
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='CDAN USPS MNIST')
    parser.add_argument('--method',
                        type=str,
                        default='CDAN-E',
                        choices=['CDAN', 'CDAN-E', 'DANN'])
    parser.add_argument('--task', default='USPS2MNIST', help='task to perform')
    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        help='input batch size for training (default: 128)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=550,
                        metavar='N',
                        help='number of epochs to train (default: 550)')
    parser.add_argument('--lr',
                        type=float,
                        default=5e-5,
                        metavar='LR',
                        help='learning rate (default: 5e-5)')
    parser.add_argument('--lr2',
                        type=float,
                        default=0.005,
                        metavar='LR2',
                        help='learning rate2 (default: 0.005)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--gpu_id',
                        type=str,
                        default='0',
                        help='cuda device id')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        help='how many batches to wait before logging training status')
    parser.add_argument('--random',
                        type=bool,
                        default=False,
                        help='whether to use random')
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    if args.task == 'USPS2MNIST':
        source_list, ordinary_train_dataset, target_list, test_list, ccp = data_loader(
            task='U2M')
        start_epoch = 50
        decay_epoch = 600
    elif args.task == 'MNIST2USPS':
        source_list, ordinary_train_dataset, target_list, test_list, ccp = data_loader(
            task='M2U')
        start_epoch = 50
        decay_epoch = 600
    else:
        raise Exception('task cannot be recognized!')

    train_loader = torch.utils.data.DataLoader(dataset=source_list,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=8,
                                               drop_last=True)
    train_loader1 = torch.utils.data.DataLoader(dataset=target_list,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=8,
                                                drop_last=True)
    o_train_loader = torch.utils.data.DataLoader(
        dataset=ordinary_train_dataset,
        batch_size=args.test_batch_size,
        shuffle=True,
        num_workers=8)
    test_loader = torch.utils.data.DataLoader(dataset=test_list,
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              num_workers=8)

    model = network.LeNet()
    model = model.cuda()
    class_num = 10

    if args.random:
        random_layer = network.RandomLayer([model.output_num(), class_num],
                                           500)
        ad_net = network.AdversarialNetwork(500, 500)
        random_layer.cuda()
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(model.output_num() * class_num,
                                            500)
    ad_net = ad_net.cuda()

    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          weight_decay=0.0005,
                          momentum=0.9)
    optimizer_ad = optim.SGD(ad_net.parameters(),
                             lr=args.lr2,
                             weight_decay=0.0005,
                             momentum=0.9)

    save_table = np.zeros(shape=(args.epochs, 3))
    for epoch in range(1, args.epochs + 1):
        if epoch % decay_epoch == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.5
        train(args, model, ad_net, random_layer, train_loader, train_loader1,
              optimizer, optimizer_ad, epoch, start_epoch, args.method, ccp)
        acc1 = test(args, model, o_train_loader)
        acc2 = test(args, model, test_loader)
        save_table[epoch - 1, :] = epoch - 50, acc1, acc2
        np.savetxt(args.task + '_.txt', save_table, delimiter=',', fmt='%1.3f')
    np.savetxt(args.task + '_.txt', save_table, delimiter=',', fmt='%1.3f')
Example #7
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='CDAN SVHN MNIST')
    parser.add_argument('--method',
                        type=str,
                        default='CDAN',
                        choices=['CDAN', 'CDAN-E', 'DANN'])
    parser.add_argument('--task', default='USPS2MNIST', help='task to perform')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.03, metavar='LR')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--gpu_id',
                        type=str,
                        default='0',
                        help='cuda device id')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        help='how many batches to wait before logging training status')
    # argparse's type=bool treats any non-empty string as True, so expose a flag instead
    parser.add_argument('--random',
                        action='store_true',
                        default=False,
                        help='whether to use random')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    source_list = '/data/svhn/svhn_train.txt'
    target_list = '/data/mnist/mnist_train.txt'
    test_list = '/data/mnist/mnist_test.txt'

    train_loader = torch.utils.data.DataLoader(ImageList(
        open(source_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize(28),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='RGB'),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)
    train_loader1 = torch.utils.data.DataLoader(ImageList(
        open(target_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize(28),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='RGB'),
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=1)
    test_loader = torch.utils.data.DataLoader(ImageList(
        open(test_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize(28),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='RGB'),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              num_workers=1)

    # img_transform_for_svhn = transforms.Compose([
    #     transforms.Resize(28),
    #     transforms.ToTensor(),
    #     transforms.Normalize(mean=(0.1307,), std=(0.3081,))
    #     # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    #     # transforms.Normalize((0.5,), (0.5,))
    # ])
    # img_transform_for_mnist = transforms.Compose([
    #     transforms.Resize(28),
    #     transforms.ToTensor(),
    #     # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    #     transforms.Normalize(mean=(0.1307,), std=(0.3081,))
    #     # transforms.Normalize((0.5,), (0.5,))
    # ])

    # dataset_source = datasets.SVHN(
    #     root=os.path.abspath(os.path.join(os.path.dirname(__file__), '../data/svhn2mnist')),
    #     split='train',
    #     transform=img_transform_for_svhn,
    #     download=True
    # )
    # train_loader = torch.utils.data.DataLoader(
    #     dataset=dataset_source,
    #     batch_size=args.batch_size,
    #     shuffle=True,
    #     num_workers=0)

    # dataset_target = datasets.MNIST(
    #     root=os.path.abspath(os.path.join(os.path.dirname(__file__), '../data/svhn2mnist')),
    #     train=True,
    #     transform=img_transform_for_mnist,
    #     download=True
    # )
    # train_loader1 = torch.utils.data.DataLoader(
    #     dataset=dataset_target,
    #     batch_size=args.batch_size,
    #     shuffle=True,
    #     num_workers=0)

    # dataset_test = datasets.MNIST(
    #     root=os.path.abspath(os.path.join(os.path.dirname(__file__), '../data/svhn2mnist')),
    #     train=False,
    #     transform=img_transform_for_mnist,
    #     download=True
    # )
    # test_loader = torch.utils.data.DataLoader(
    #     dataset=dataset_test,
    #     batch_size=args.batch_size,
    #     shuffle=True,
    #     num_workers=0)

    model = network.DTN()
    # model = model.cuda()
    class_num = 10

    if args.random:
        random_layer = network.RandomLayer([model.output_num(), class_num],
                                           500)
        ad_net = network.AdversarialNetwork(500, 500)
        # random_layer.cuda()
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(model.output_num() * class_num,
                                            500)
    # ad_net = ad_net.cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          weight_decay=0.0005,
                          momentum=0.9)
    optimizer_ad = optim.SGD(ad_net.parameters(),
                             lr=args.lr,
                             weight_decay=0.0005,
                             momentum=0.9)

    for epoch in range(1, args.epochs + 1):
        if epoch % 3 == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.3
        train(args, model, ad_net, random_layer, train_loader, train_loader1,
              optimizer, optimizer_ad, epoch, 0, args.method)
        test(args, model, test_loader)
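
# Why ad_net's input size is model.output_num() * class_num: CDAN conditions
# the domain discriminator on the outer product of features and softmax
# predictions, as in the reference CDAN implementation. A minimal sketch
# (batch size and dimensions below are illustrative):
def cdan_multilinear_sketch():
    import torch
    batch, feat_dim, class_num = 4, 500, 10
    features = torch.randn(batch, feat_dim)
    softmax_out = torch.softmax(torch.randn(batch, class_num), dim=1)
    # batched outer product: (batch, class_num, 1) x (batch, 1, feat_dim)
    op_out = torch.bmm(softmax_out.unsqueeze(2), features.unsqueeze(1))
    disc_input = op_out.view(batch, feat_dim * class_num)
    print(disc_input.shape)  # torch.Size([4, 5000])
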
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='CDAN USPS MNIST')
    parser.add_argument('--method', type=str, default='CDAN-E', choices=['CDAN', 'CDAN-E', 'DANN'])
    parser.add_argument('--task', default='USPS2MNIST', help='task to perform')
    parser.add_argument('--batch_size', type=int, default=256, help='input batch size for training (default: 64)')
    parser.add_argument('--test_batch_size', type=int, default=1000, help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N', help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR', help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M', help='SGD momentum (default: 0.5)')
    parser.add_argument('--gpu_id', type=str, default="0", help='cuda device id')
    parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
    parser.add_argument('--log_interval', type=int, default=10, help='how many batches to wait before logging training status')
    # argparse's type=bool treats any non-empty string as True, so expose a flag instead
    parser.add_argument('--random', action='store_true', default=False, help='whether to use random')
    parser.add_argument('--output_dir', type=str, default="digits/u2m")
    parser.add_argument('--cla_plus_weight', type=float, default=0.3)
    parser.add_argument('--cyc_loss_weight', type=float, default=0.01)
    parser.add_argument('--weight_in_loss_g', type=str, default='1,0.01,0.1,0.1')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    # train config
    import os.path as osp
    import datetime
    config = {}

    config['method'] = args.method
    config["gpu"] = args.gpu_id
    config['cyc_loss_weight'] = args.cyc_loss_weight
    config['cla_plus_weight'] = args.cla_plus_weight
    config['weight_in_loss_g'] = args.weight_in_loss_g
    config["epochs"] = args.epochs
    config["output_for_test"] = True
    config["output_path"] = "snapshot/" + args.output_dir
    if not osp.exists(config["output_path"]):
        os.makedirs(config["output_path"], exist_ok=True)
    config["out_file"] = open(osp.join(config["output_path"], "log_{}_{}.txt".
                                       format(args.task, str(datetime.datetime.utcnow()))),
                              "w")

    config["out_file"].write(str(config))
    config["out_file"].flush()


    source_list = '/data/usps/usps_train.txt'
    target_list = '/data/mnist/mnist_train.txt'
    test_list = '/data/mnist/mnist_test.txt'
    start_epoch = 1
    decay_epoch = 6


    train_loader = torch.utils.data.DataLoader(
        ImageList(open(source_list).readlines(), transform=transforms.Compose([
                           transforms.Resize((28,28)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,))
                       ]), mode='L'),
        batch_size=args.batch_size, shuffle=True, num_workers=1, drop_last=True)
    train_loader1 = torch.utils.data.DataLoader(
        ImageList(open(target_list).readlines(), transform=transforms.Compose([
                           transforms.Resize((28,28)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,))
                       ]), mode='L'),
        batch_size=args.batch_size, shuffle=True, num_workers=1, drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        ImageList(open(test_list).readlines(), transform=transforms.Compose([
                           transforms.Resize((28,28)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,))
                       ]), mode='L'),
        batch_size=args.test_batch_size, shuffle=True, num_workers=1)

    model = network.LeNet()
    # model = model.cuda()
    class_num = 10

    # add generators (G), discriminators (D), and an extra classifier
    import itertools
    from utils import ReplayBuffer
    import net
    z_dimension = 500
    D_s = network.models["Discriminator_um"]()
    # D_s = D_s.cuda()
    G_s2t = network.models["Generator_um"](z_dimension, 500)
    # G_s2t = G_s2t.cuda()

    D_t = network.models["Discriminator_um"]()
    # D_t = D_t.cuda()
    G_t2s = network.models["Generator_um"](z_dimension, 500)
    # G_t2s = G_t2s.cuda()

    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()
    criterion_Sem = torch.nn.L1Loss()

    optimizer_G = torch.optim.Adam(itertools.chain(G_s2t.parameters(), G_t2s.parameters()), lr=0.0003)
    optimizer_D_s = torch.optim.Adam(D_s.parameters(), lr=0.0003)
    optimizer_D_t = torch.optim.Adam(D_t.parameters(), lr=0.0003)

    fake_S_buffer = ReplayBuffer()
    fake_T_buffer = ReplayBuffer()

    ## add the extra classifier
    classifier1 = net.Net(500, class_num)
    # classifier1 = classifier1.cuda()
    classifier1_optim = optim.Adam(classifier1.parameters(), lr=0.0003)


    if args.random:
        random_layer = network.RandomLayer([model.output_num(), class_num], 500)
        ad_net = network.AdversarialNetwork(500, 500)
        # random_layer.cuda()
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(model.output_num() * class_num, 500)
    # ad_net = ad_net.cuda()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=0.0005, momentum=0.9)
    optimizer_ad = optim.SGD(ad_net.parameters(), lr=args.lr, weight_decay=0.0005, momentum=0.9)

    for epoch in range(1, args.epochs + 1):
        if epoch % decay_epoch == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.5
        train(args, model, ad_net, random_layer, train_loader, train_loader1, optimizer, optimizer_ad, epoch, start_epoch, args.method,
              D_s, D_t, G_s2t, G_t2s, criterion_Sem, criterion_GAN, criterion_cycle, criterion_identity, optimizer_G,
              optimizer_D_t, optimizer_D_s,
              classifier1, classifier1_optim, fake_S_buffer, fake_T_buffer
              )
        test(args, epoch, config, model, test_loader)
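
# The four criteria above follow the usual CycleGAN recipe applied to feature
# vectors. A minimal self-contained sketch of one generator update (toy
# nn.Linear generators and discriminator, not the repo's
# Generator_um/Discriminator_um; the 10.0 cycle weight is the conventional
# CycleGAN default and an assumption here):
def cycle_update_sketch():
    import itertools
    import torch
    import torch.nn as nn
    dim = 500
    G_s2t, G_t2s = nn.Linear(dim, dim), nn.Linear(dim, dim)  # toy generators
    D_t = nn.Linear(dim, 1)                                  # toy target-domain discriminator
    criterion_GAN, criterion_cycle = nn.MSELoss(), nn.L1Loss()
    optimizer_G = torch.optim.Adam(itertools.chain(G_s2t.parameters(), G_t2s.parameters()), lr=0.0003)
    feat_s = torch.randn(8, dim)   # a batch of source features
    fake_t = G_s2t(feat_s)         # translate source -> target
    # adversarial term: make D_t score translated features as real (label 1)
    loss_gan = criterion_GAN(D_t(fake_t), torch.ones(8, 1))
    # cycle term: translating back should recover the original features
    loss_cyc = criterion_cycle(G_t2s(fake_t), feat_s)
    optimizer_G.zero_grad()
    (loss_gan + 10.0 * loss_cyc).backward()
    optimizer_G.step()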
Example #9
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='CDAN SVHN MNIST')
    parser.add_argument('--method',
                        type=str,
                        default='CDAN-E',
                        choices=['CDAN', 'CDAN-E', 'DANN'])
    parser.add_argument('--task', default='USPS2MNIST', help='task to perform')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.03, metavar='LR')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--gpu_id', type=str, default='0', help='cuda device id')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        help='how many batches to wait before logging training status')
    # argparse's type=bool treats any non-empty string as True, so expose a flag instead
    parser.add_argument('--random',
                        action='store_true',
                        default=False,
                        help='whether to use random')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    source_list = '../data/svhn2mnist/svhn_balanced.txt'
    target_list = '../data/svhn2mnist/mnist_train.txt'
    test_list = '../data/svhn2mnist/mnist_test.txt'

    train_loader = torch.utils.data.DataLoader(ImageList(
        open(source_list).readlines(),
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, ), (0.5, ))]),
        mode='RGB'),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)
    train_loader1 = torch.utils.data.DataLoader(ImageList(
        open(target_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='RGB'),
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=1)
    test_loader = torch.utils.data.DataLoader(ImageList(
        open(test_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='RGB'),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              num_workers=1)

    model = network.DTN()
    model = model.cuda()
    class_num = 10

    if args.random:
        random_layer = network.RandomLayer([model.output_num(), class_num],
                                           500)
        ad_net = network.AdversarialNetwork(500, 500)
        random_layer.cuda()
    else:
        random_layer = None
        print('ad_net input dims:', model.output_num(), class_num)
        ad_net = network.AdversarialNetwork(model.output_num() * class_num,
                                            500)
    ad_net = ad_net.cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          weight_decay=0.0005,
                          momentum=0.9)
    optimizer_ad = optim.SGD(ad_net.parameters(),
                             lr=args.lr,
                             weight_decay=0.0005,
                             momentum=0.9)

    for epoch in range(1, args.epochs + 1):
        if epoch % 3 == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.3
        train(args, model, ad_net, random_layer, train_loader, train_loader1,
              optimizer, optimizer_ad, epoch, 0, args.method)
        acc = test(args, model, test_loader)
        with summary_writer.as_default():
            tf.summary.scalar("acc", acc, step=epoch)
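
# `summary_writer` and `tf` are used above but never defined in this excerpt;
# presumably they are created earlier in the file. A minimal sketch of the
# assumed setup under TensorFlow 2 (the log directory is hypothetical):
# import tensorflow as tf
# summary_writer = tf.summary.create_file_writer("logs/svhn2mnist")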
Example #10
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    tensor_writer = SummaryWriter(config["tensorboard_path"])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
            shuffle=True, num_workers=4, drop_last=True)

    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
            shuffle=True, num_workers=4, drop_last=True)

    if prep_config["test_10crop"]:
        dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
                            transform=prep_dict["test"][i]) for i in range(10)]
        dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                            shuffle=False, num_workers=4) for dset in dsets['test']]
        dsets["source_val"] = [ImageList(open(data_config["source"]["list_path"]).readlines(), \
                            transform=prep_dict["test"][i]) for i in range(10)]
        dset_loaders["source_val"] = [DataLoader(dset, batch_size=test_bs, \
                            shuffle=False, num_workers=4) for dset in dsets['source_val']]
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                shuffle=False, num_workers=4)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.cuda()

    ## add additional network for some methods
    if config["loss"]["random"]:
        random_layer = network.RandomLayer(
            [base_network.output_num(), class_num],
            config["loss"]["random_dim"])
        ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
    else:
        random_layer = None
        # ad_net = network.AdversarialNetwork(base_network.output_num() * class_num, 1024)
        ad_net = network.AdversarialNetwork(base_network.output_num(), 1024)
    if config["loss"]["random"]:
        random_layer.cuda()
    ad_net = ad_net.cuda()
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
        base_network = nn.DataParallel(base_network,
                                       device_ids=[int(i) for i in gpus])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == config["test_interval"] - 1:
            base_network.train(False)
            temp_acc, output, prediction, label, feature = image_classification_test(dset_loaders, \
                base_network, test_10crop=prep_config["test_10crop"])
            _, output_src, prediction_src, label_src, feature_src = image_classification_val(dset_loaders, \
                base_network, test_10crop=prep_config["test_10crop"])
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model
        if i % config["snapshot_interval"] == 0:
            torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \
                "iter_{:05d}_model.pth.tar".format(i)))

        loss_params = config["loss"]
        ## train one iter
        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        inputs_source, inputs_target, labels_source = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()
        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)

        target_softmax_out = nn.Softmax(dim=1)(outputs_target)
        target_entropy = EntropyLoss(target_softmax_out)

        temperature = 3.0
        outputs_target_temp = outputs_target / temperature
        target_softmax_out_temp = nn.Softmax(dim=1)(outputs_target_temp)
        target_entropy_weight = loss.Entropy(target_softmax_out_temp).detach()
        target_entropy_weight = 1 + torch.exp(-target_entropy_weight)
        target_entropy_weight = train_bs * target_entropy_weight / torch.sum(
            target_entropy_weight)

        cov_matrix_t_temp = target_softmax_out_temp.mul(
            target_entropy_weight.view(-1, 1)).transpose(
                1, 0).mm(target_softmax_out_temp)
        cov_matrix_t_temp = cov_matrix_t_temp / torch.sum(cov_matrix_t_temp,
                                                          dim=1)

        mcc_loss = (torch.sum(cov_matrix_t_temp) -
                    torch.trace(cov_matrix_t_temp)) / class_num

        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        total_loss = classifier_loss + mcc_loss
        total_loss.backward()
        optimizer.step()

        tensor_writer.add_scalar('total_loss', total_loss, i)
        tensor_writer.add_scalar('classifier_loss', classifier_loss, i)
        tensor_writer.add_scalar('cov_matrix_penalty', mcc_loss, i)

    torch.save(best_model, osp.join(config["output_path"],
                                    "best_model.pth.tar"))
    return best_acc
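
# The entropy-weighted class-confusion term above (temperature-scaled softmax,
# confidence reweighting, row-normalized class covariance, off-diagonal sum)
# reads more easily as a standalone helper. A minimal sketch mirroring the
# loop's computation (the 1e-8 guard is an added numerical safeguard):
def class_confusion_loss(logits_target, class_num, temperature=3.0):
    import torch
    import torch.nn as nn
    batch = logits_target.size(0)
    probs = nn.Softmax(dim=1)(logits_target / temperature)
    # per-sample entropy -> weights that favour confident predictions
    entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=1).detach()
    weight = 1 + torch.exp(-entropy)
    weight = batch * weight / torch.sum(weight)
    # weighted class covariance, row-normalized as in the loop above
    cov = probs.mul(weight.view(-1, 1)).transpose(1, 0).mm(probs)
    cov = cov / torch.sum(cov, dim=1)
    return (torch.sum(cov) - torch.trace(cov)) / class_num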
Example #11
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
                                        shuffle=True, num_workers=4, drop_last=True)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
                                        shuffle=True, num_workers=4, drop_last=True)

    if prep_config["test_10crop"]:
        dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                   transform=prep_dict["test"][i]) for i in range(10)]
        dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                                           shuffle=False, num_workers=4) for dset in dsets['test']]
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                  transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                          shuffle=False, num_workers=4)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.cuda()


    with torch.no_grad():
        # loaders for the clustering / pseudo-labeling pass over each domain
        cluster_data_loader = {}
        cluster_data_loader["source"] = DataLoader(dsets["source"], batch_size=100, \
                                                   shuffle=True, num_workers=0, drop_last=True)
        cluster_data_loader["target"] = DataLoader(dsets["target"], batch_size=100, \
                                                   shuffle=True, num_workers=0, drop_last=True)


    ## add additional network for some methods
    if config["loss"]["random"]:
        random_layer = network.RandomLayer([base_network.output_num(), class_num], config["loss"]["random_dim"])
        ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(base_network.output_num() * class_num, 1024)
    if config["loss"]["random"]:
        random_layer.cuda()
    ad_net = ad_net.cuda()
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                                         **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
        base_network = nn.DataParallel(base_network, device_ids=[int(i) for i in gpus])

    # dset_loaders["ps_target"]=[]
    ## train
    len_train_source = len(dset_loaders["source"])
    # len_train_target = len(dset_loaders["ps_target"])
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    for i in range(config["num_iterations"]):
        lamb = adaptation_factor((i+1)/10000)
        cls_lamb = adaptation_factor(5*(i+1)/10000)
        epoch = int(i / len_train_source)
        if i % len_train_source == 0:
            testing = True
            pl_update = True
            print_loss = True
            # print("epoch: {} ".format(int(i / len_train_source)))
        if epoch % 5 == 0 and pl_update:
            pl_update = False
            # del dset_loaders["ps_target"]
            pseudo_labeled_targets, target_g_ctr, source_g_ctr = pseudo_labeling(base_network, cluster_data_loader, class_num)
            global_source_ctr = source_g_ctr.detach_()
            global_target_ctr = target_g_ctr.detach_()
            if len(pseudo_labeled_targets["label_list"]) != 0:
                print("new pl at epoch {}".format(epoch))

                pseudo_dataset = PS_ImageList(pseudo_labeled_targets, transform=prep_dict["target"])

                dset_loaders["ps_target"] = DataLoader(pseudo_dataset, batch_size=train_bs, \
                                                       shuffle=False, num_workers=0, drop_last=True)
                len_train_target = len(dset_loaders["ps_target"])
            else:
                print("no pl at epoch {}".format(epoch))
            # print("pseudo labeling done")
        # print(i)

        # if i % config["test_interval"] == config["test_interval"] - 1:
        if epoch % 5 == 0 and testing and i > 0:
            base_network.train(False)
            temp_acc, v_loss = image_classification_test(dset_loaders, \
                                                 base_network, test_10crop=prep_config["test_10crop"])
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()
            print(log_str)
            testing = False

            now = datetime.now()
            current_time = now.strftime("%H:%M:%S")
            print("epoch: {} ".format(int(i / len_train_source)))
            print("time: {} ".format(current_time))
            print("best acc: {} ".format(best_acc))
            print("loss: {} ".format(v_loss))
            print("adaptation rate: {}".format(lamb))
            print("learning rate: {} {} {} {}".format(optimizer.param_groups[0]["lr"], optimizer.param_groups[1]["lr"], optimizer.param_groups[2]["lr"], optimizer.param_groups[3]["lr"]))
            print("------------")
        if i % config["snapshot_interval"] == 0:
            torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \
                                                             "iter_{:05d}_model.pth.tar".format(i)))

        loss_params = config["loss"]
        ## train one iter
        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)

        optimizer.zero_grad()

        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            # print(i, len_train_target)
            iter_target = iter(dset_loaders["ps_target"])
        try:
            inputs_source, labels_source, _ = next(iter_source)
            inputs_target, labels_target = next(iter_target)
        except StopIteration:
            iter_target = iter(dset_loaders["ps_target"])
            inputs_target, labels_target = next(iter_target)

        inputs_source, inputs_target, labels_source, labels_target = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda(), labels_target.cuda()

        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)

        ## class-aware centroids
        batch_source_centroids = utils.get_batch_centers(features_source, labels_source, class_num)
        batch_target_centroids = utils.get_batch_centers(features_target, labels_target, class_num)

        # if i==0:
        #     global_source_ctr = batch_source_centroids
        #     global_target_ctr = batch_target_centroids
        # if i>0:
        batch_source_centroids = ctr_adapt_factor * global_source_ctr + (1 - ctr_adapt_factor) * batch_source_centroids
        batch_target_centroids = ctr_adapt_factor * global_target_ctr + (1 - ctr_adapt_factor) * batch_target_centroids
        global_source_ctr = batch_source_centroids.clone().detach_()
        global_target_ctr = batch_target_centroids.clone().detach_()
        #
        # global_source_ctr = global_source_ctr.cpu().data.numpy()
        # global_target_ctr.detach_()

        # ctr_alignment_loss = utils.cosine_distance(global_source_ctr,global_target_ctr,cross=False)



        # source_p2c_Distances = 0 - utils.cosine_distance(features_source, global_source_ctr, cross=True)
        #
        # target_p2c_Distances = 0 - utils.cosine_distance(features_target, global_target_ctr, cross=True)
        #
        #
        #
        # zero_ctrs_s = torch.unique(torch.where(global_source_ctr==0)[0])
        # zero_ctrs_t = torch.unique(torch.where(global_target_ctr == 0)[0])
        alignment_index = []
        identity = np.eye(class_num)
        ctr_alignment_count = 0
        pos = []
        post = []
        neg = []
        negt = []
        index_s = np.empty([0, 1])
        index_t = np.empty([0, 1])
        itt = 0
        triplets = {}
        # with torch.no_grad():

        labels = labels_source.cpu().data.numpy()
        labelt = labels_target.cpu().data.numpy()
        # zero_ctrs_s = zero_ctrs_s.cpu().data.numpy()
        # zero_ctrs_t = zero_ctrs_t.cpu().data.numpy()

        #####npair
        # labels = labels.cpu().data.numpy()
        n_pairs = []

        for label in set(labels):
            label_mask = (labels == label)
            label_indices = np.where(label_mask)[0]
            if len(label_indices) < 1:
                continue
            anchor = np.random.choice(label_indices, 1, replace=False)
            n_pairs.append([anchor, np.array([label])])

        n_pairs = np.array(n_pairs)

        n_negatives = []
        # use j for the pair index; i is the outer training-iteration counter
        for j in range(len(n_pairs)):
            negative = np.concatenate([n_pairs[:j, 1], n_pairs[j + 1:, 1]])
            n_negatives.append(negative)

        n_negatives = np.array(n_negatives)
        n_pairs_s = torch.LongTensor(n_pairs)
        n_neg_s = torch.LongTensor(n_negatives)

        n_pairs = []
        for label in set(labelt):
            label_mask = (labelt == label)
            label_indices = np.where(label_mask)[0]
            if len(label_indices) < 1:
                continue
            anchor = np.random.choice(label_indices, 1, replace=False)
            n_pairs.append([anchor, np.array([label])])

        n_pairs = np.array(n_pairs)

        n_negatives = []
        for j in range(len(n_pairs)):
            negative = np.concatenate([n_pairs[:j, 1], n_pairs[j + 1:, 1]])
            n_negatives.append(negative)

        n_negatives = np.array(n_negatives)
        n_pairs_t = torch.LongTensor(n_pairs)
        n_neg_t = torch.LongTensor(n_negatives)
        # return torch.LongTensor(n_pairs), torch.LongTensor(n_negatives)
        #####

        for it in range(class_num):
            label_mask = (labels == it)
            label_maskt = (labelt == it)
            idx = np.where(label_mask)[0]
            idxt = np.where(label_maskt)[0]
            # idx = torch.flatten(torch.nonzero(labels_source == torch.tensor(it).cuda()))
            if len(idx) != 0:
                index_s = np.append(index_s, idx)
                pos += [it for cc in range(len(idx))]
                mask = 1 - identity[it, :]
                neg_id = np.nonzero(mask.flatten())[0].flatten()
                # neg_idx = np.where(np.in1d(neg_id, zero_ctrs_s) != True)[0]
                neg += [[neg_id] for cc in range(len(idx))]

            if len(idxt) != 0:
                index_t = np.append(index_t, idxt)
                post += [it for cc in range(len(idxt))]
                maskt = 1 - identity[it, :]
                neg_idt = np.nonzero(maskt.flatten())[0].flatten()
                # neg_idxt = np.where(np.in1d(neg_idt, zero_ctrs_t))[0]
                negt += [[neg_idt] for cc in range(len(idxt))]

            # alignment_ctr_idx = idx[torch.nonzero(torch.where(idx == idxt, idx, 0))]
            if len(idx) != 0 and len(idxt) != 0:
                ctr_alignment_count += 1
                alignment_index += [it]
                # alignment_loss += [utils.cosine_distance(batch_source_centroids[it], batch_source_centroids[it], cross=False)]
        # tempp = torch.cat(source_loss,0)
        # posetives_s = torch.cat(pos, dim=0)
        # negatives_s = torch.cat(neg, dim=0)
        # posetives_t = torch.cat(post, dim=0)
        # negatives_t = torch.cat(negt, dim=0)
        # a_i = torch.LongTensor(index_s.flatten()).cuda()
        # a_p = torch.LongTensor(pos).cuda()
        # a_n = torch.LongTensor(neg).cuda()
        ctr_alignment_loss = 0
        # index_s/index_t were accumulated via np.append (float64); cast to int for tensor indexing
        anchors_s = features_source[index_s.flatten().astype(int), :]
        positive_s = global_source_ctr[pos, :]
        negative_s = global_source_ctr[neg].squeeze(1)
        # n_pairs_s = n_pairs_s.cuda().squeeze(2)
        # n_neg_s = n_neg_s.cuda().squeeze(2)
        # anchors_s = features_source[n_pairs_s[:, 0]]
        # positive_s = global_source_ctr[n_pairs_s[:, 1]]
        # negative_s = global_source_ctr[n_neg_s]
        #
        # n_pairs_t = n_pairs_t.cuda().squeeze(2)
        #
        # n_neg_t = n_neg_t.cuda().squeeze(2)
        # anchors_t = features_source[n_pairs_t[:, 0]]
        # positive_t = global_source_ctr[n_pairs_t[:, 1]]
        # negative_t = global_source_ctr[n_neg_t]
        # anchors_s.retain_graph=True
        # positive_s.retain_graph=True
        # negative_s.retain_graph=True

        anchors_t = features_target[index_t.flatten().astype(int), :]
        positive_t = global_target_ctr[post, :]
        negative_t = global_target_ctr[negt].squeeze(1)
        # FAT_loss = torch.empty([],requires_grad=True)
        # FAT_loss.requires_grad = True
        # FAT_loss.retain_grad()
        # nfat_s = Variable(n_pair_loss(anchors_s,positive_s, negative_s,class_num,train_bs))
        # nfat_t = Variable(n_pair_loss(anchors_t,positive_t, negative_t,class_num,train_bs))
        # FAT_loss.requires_grad = True
        # FAT_loss.retain_grad()
        FAT_loss = n_pair_loss(anchors_s, positive_s, negative_s, class_num, train_bs) + n_pair_loss(anchors_t, positive_t, negative_t, class_num, train_bs) / 2

        if len(alignment_index) != 0:
            ctr_alignment_loss = torch.sum(utils.cosine_distance(batch_source_centroids[alignment_index], batch_target_centroids[alignment_index], cross=False))  # /ctr_alignment_count
        # source_batch_FAT_Loss = torch.mean(torch.cat(source_loss,0), 0)/class_num
        # target_batch_FAT_Loss = torch.mean(torch.cat(target_loss,0),0)/class_num
        #
        # FAT_loss = source_batch_FAT_Loss.add(target_batch_FAT_Loss)
        ##
        # print("train loss: ", FAT_loss)
        # ctr_alignment_loss.grad_required =True
        # ctr_alignment_loss.retain_grad()
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)
        if config['method'] == 'CDAN+E':
            entropy = loss.Entropy(softmax_out)
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy, network.calc_coeff(i), random_layer)
        elif config['method'] == 'CDAN':
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, None, None, random_layer)
        elif config['method'] == 'DANN':
            transfer_loss = loss.DANN(features, ad_net)
        else:
            raise ValueError('Method cannot be recognized.')
        classifier_loss = nn.CrossEntropyLoss()(outputs_source / 2, labels_source)

        total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss
        if lamb > .1:
            cls_lamb = 1.0
        else:
            cls_lamb = 10 * lamb

        # total_loss = lamb * ( FAT_loss + 10*ctr_alignment_loss) + (transfer_loss) + cls_lamb*classifier_loss
        # total_loss =transfer_loss + lamb * (FAT_loss + ctr_alignment_loss) + classifier_loss
        # FAT_loss.backward(retain_graph=True)
        # optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        # my_lr_scheduler.step()
        if epoch % 5 == 0 and print_loss:
            print("fat loss ", FAT_loss)  # .grad_fn, FAT_loss.requires_grad
            print("ctr align:  ", ctr_alignment_loss)
            print("tot: ", total_loss)
            print("clss: ", classifier_loss)
            print("trs: ", transfer_loss)
            print("++++++++++++++++++++++++end of epoch++++++++++++++++++++")

            print_loss = False
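
# The anchor/negative bookkeeping above is duplicated verbatim for the source
# and target batches. A minimal sketch factoring it into one helper under the
# same conventions (one random anchor per class present in the batch, every
# other class label as a negative); `build_n_pairs` is a hypothetical name:
def build_n_pairs(labels):
    import numpy as np
    import torch
    n_pairs = []
    for label in set(labels):
        indices = np.where(labels == label)[0]
        anchor = np.random.choice(indices, 1, replace=False)
        n_pairs.append([anchor, np.array([label])])
    n_pairs = np.array(n_pairs)
    # for each pair, the labels of all the other pairs serve as negatives
    negatives = [np.concatenate([n_pairs[:j, 1], n_pairs[j + 1:, 1]]) for j in range(len(n_pairs))]
    return torch.LongTensor(n_pairs), torch.LongTensor(np.array(negatives))
# usage: n_pairs_s, n_neg_s = build_n_pairs(labels); n_pairs_t, n_neg_t = build_n_pairs(labelt)
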
def train(config):

    ## Define start time
    start_time = time.time()

    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    print("Preparing data", flush=True)
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    root_folder = data_config["root_folder"]
    dsets["source"] = ImageList(open(osp.join(root_folder, data_config["source"]["list_path"])).readlines(), \
                                transform=prep_dict["source"], root_folder=root_folder, ratios=config["ratios_source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
            shuffle=True, num_workers=4, drop_last=True)
    dsets["target"] = ImageList(open(osp.join(root_folder, data_config["target"]["list_path"])).readlines(), \
                                transform=prep_dict["target"], root_folder=root_folder, ratios=config["ratios_target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
            shuffle=True, num_workers=4, drop_last=True)

    dsets["test"] = ImageList(open(
        osp.join(root_folder, data_config["test"]["list_path"])).readlines(),
                              transform=prep_dict["test"],
                              root_folder=root_folder,
                              ratios=config["ratios_test"])
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                            shuffle=False, num_workers=4)

    test_path = os.path.join(root_folder, data_config["test"]["dataset_path"])
    if os.path.exists(test_path):
        print('Found existing dataset for test', flush=True)
        with open(test_path, 'rb') as f:
            [test_samples, test_labels] = pickle.load(f)
            test_labels = torch.LongTensor(test_labels).to(config["device"])
    else:
        print('Missing test dataset', flush=True)
        print('Building dataset for test and writing to {}'.format(test_path),
              flush=True)
        dset_test = ImageList(open(
            osp.join(root_folder,
                     data_config["test"]["list_path"])).readlines(),
                              transform=prep_dict["test"],
                              root_folder=root_folder,
                              ratios=config['ratios_test'])
        loaded_dset_test = LoadedImageList(dset_test)
        test_samples, test_labels = loaded_dset_test.samples.numpy(
        ), loaded_dset_test.targets.numpy()
        with open(test_path, 'wb') as f:
            pickle.dump([test_samples, test_labels], f)

    class_num = config["network"]["params"]["class_num"]
    test_samples, test_labels = sample_ratios(test_samples, test_labels,
                                              config['ratios_test'])

    # compute labels distribution on the source and target domain
    source_label_distribution = np.zeros((class_num))
    for img in dsets["source"].imgs:
        source_label_distribution[img[1]] += 1
    print("Total source samples: {}".format(np.sum(source_label_distribution)),
          flush=True)
    print("Source samples per class: {}".format(source_label_distribution),
          flush=True)
    source_label_distribution /= np.sum(source_label_distribution)
    print("Source label distribution: {}".format(source_label_distribution),
          flush=True)
    target_label_distribution = np.zeros((class_num))
    for img in dsets["target"].imgs:
        target_label_distribution[img[1]] += 1
    print("Total target samples: {}".format(np.sum(target_label_distribution)),
          flush=True)
    print("Target samples per class: {}".format(target_label_distribution),
          flush=True)
    target_label_distribution /= np.sum(target_label_distribution)
    print("Target label distribution: {}".format(target_label_distribution),
          flush=True)
    mixture = (source_label_distribution + target_label_distribution) / 2
    jsd = (scipy.stats.entropy(source_label_distribution, qk=mixture) \
            + scipy.stats.entropy(target_label_distribution, qk=mixture)) / 2
    print("JSD : {}".format(jsd), flush=True)

    test_label_distribution = np.zeros((class_num))
    for img in test_labels:
        test_label_distribution[int(img.item())] += 1
    print("Test samples per class: {}".format(test_label_distribution),
          flush=True)
    test_label_distribution /= np.sum(test_label_distribution)
    print("Test label distribution: {}".format(test_label_distribution),
          flush=True)
    write_list(config["out_wei_file"],
               [round(x, 4) for x in test_label_distribution])
    write_list(config["out_wei_file"],
               [round(x, 4) for x in source_label_distribution])
    write_list(config["out_wei_file"],
               [round(x, 4) for x in target_label_distribution])
    true_weights = torch.tensor(
        target_label_distribution / source_label_distribution,
        dtype=torch.float,
        requires_grad=False)[:, None].to(config["device"])
    print("True weights : {}".format(true_weights[:, 0].cpu().numpy()))
    config["out_wei_file"].write(str(jsd) + "\n")

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.to(config["device"])

    ## add additional network for some methods
    if config["loss"]["random"]:
        random_layer = network.RandomLayer(
            [base_network.output_num(), class_num],
            config["loss"]["random_dim"])
        ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
    else:
        random_layer = None
        if 'CDAN' in config['method']:
            ad_net = network.AdversarialNetwork(
                base_network.output_num() * class_num, 1024)
        else:
            ad_net = network.AdversarialNetwork(base_network.output_num(),
                                                1024)
    if config["loss"]["random"]:
        random_layer.to(config["device"])
    ad_net = ad_net.to(config["device"])
    parameter_list = ad_net.get_parameters() + base_network.get_parameters()
    parameter_list[-1]["lr_mult"] = config["lr_mult_im"]

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    # Maintain two quantities for the QP.
    cov_mat = torch.tensor(np.zeros((class_num, class_num), dtype=np.float32),
                           requires_grad=False).to(config["device"])
    pseudo_target_label = torch.tensor(np.zeros((class_num, 1),
                                                dtype=np.float32),
                                       requires_grad=False).to(
                                           config["device"])
    # Maintain one weight vector for BER.
    class_weights = torch.tensor(1.0 / source_label_distribution,
                                 dtype=torch.float,
                                 requires_grad=False).to(config["device"])

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
        base_network = nn.DataParallel(base_network,
                                       device_ids=[int(i) for i in gpus])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0

    print("Preparations done in {:.0f} seconds".format(time.time() -
                                                       start_time),
          flush=True)
    print("Starting training for {} iterations using method {}".format(
        config["num_iterations"], config['method']),
          flush=True)
    start_time_test = start_time = time.time()
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == config["test_interval"] - 1:
            base_network.train(False)
            temp_acc = image_classification_test_loaded(
                test_samples, test_labels, base_network)
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
            log_str = "  iter: {:05d}, sec: {:.0f}, class: {:.5f}, da: {:.5f}, precision: {:.5f}".format(
                i,
                time.time() - start_time_test, classifier_loss_value,
                transfer_loss_value, temp_acc)
            config["out_log_file"].write(log_str + "\n")
            config["out_log_file"].flush()
            print(log_str, flush=True)
            if 'IW' in config['method']:
                current_weights = [
                    round(x, 4) for x in
                    base_network.im_weights.data.cpu().numpy().flatten()
                ]
                # write_list(config["out_wei_file"], current_weights)
                print(current_weights, flush=True)
            start_time_test = time.time()
        if i % 500 == -1:  # never true (i % 500 is never -1), so this timing log is effectively disabled
            print("{} iterations in {} seconds".format(
                i,
                time.time() - start_time),
                  flush=True)

        loss_params = config["loss"]
        ## train one iter
        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()

        t = time.time()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, label_source = next(iter_source)
        inputs_target, _ = next(iter_target)
        inputs_source, inputs_target, label_source = inputs_source.to(
            config["device"]), inputs_target.to(
                config["device"]), label_source.to(config["device"])
        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)

        if 'IW' in config['method']:
            ys_onehot = torch.zeros(train_bs, class_num).to(config["device"])
            ys_onehot.scatter_(1, label_source.view(-1, 1), 1)

            # Compute weights on source data.
            if 'ORACLE' in config['method']:
                weights = torch.mm(ys_onehot, true_weights)
            else:
                weights = torch.mm(ys_onehot, base_network.im_weights)

            source_preds, target_preds = outputs[:train_bs], outputs[train_bs:]
            # Compute the aggregated distribution of pseudo-label on the target domain.
            pseudo_target_label += torch.sum(F.softmax(target_preds, dim=1),
                                             dim=0).view(-1, 1).detach()
            # Update the covariance matrix on the source domain as well.
            cov_mat += torch.mm(
                F.softmax(source_preds, dim=1).transpose(1, 0),
                ys_onehot).detach()

        if config['method'] == 'CDAN-E':
            classifier_loss = nn.CrossEntropyLoss()(outputs_source,
                                                    label_source)
            entropy = loss.Entropy(softmax_out)
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy,
                                      network.calc_coeff(i), random_layer)
            total_loss = loss_params["trade_off"] * \
                transfer_loss + classifier_loss

        elif 'IWCDAN-E' in config['method']:

            classifier_loss = torch.mean(
                nn.CrossEntropyLoss(weight=class_weights, reduction='none')
                (outputs_source, label_source) * weights) / class_num

            entropy = loss.Entropy(softmax_out)
            transfer_loss = loss.CDAN([features, softmax_out],
                                      ad_net,
                                      entropy,
                                      network.calc_coeff(i),
                                      random_layer,
                                      weights=weights,
                                      device=config["device"])
            total_loss = loss_params["trade_off"] * \
                transfer_loss + classifier_loss

        elif config['method'] == 'CDAN':

            classifier_loss = nn.CrossEntropyLoss()(outputs_source,
                                                    label_source)
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, None,
                                      None, random_layer)
            total_loss = loss_params[
                "trade_off"] * transfer_loss + classifier_loss

        elif 'IWCDAN' in config['method']:

            classifier_loss = torch.mean(
                nn.CrossEntropyLoss(weight=class_weights, reduction='none')
                (outputs_source, label_source) * weights) / class_num

            transfer_loss = loss.CDAN([features, softmax_out],
                                      ad_net,
                                      None,
                                      None,
                                      random_layer,
                                      weights=weights)
            total_loss = loss_params["trade_off"] * \
                transfer_loss + classifier_loss

        elif config['method'] == 'DANN':
            classifier_loss = nn.CrossEntropyLoss()(outputs_source,
                                                    label_source)
            transfer_loss = loss.DANN(features, ad_net, config["device"])
            total_loss = loss_params["trade_off"] * \
                transfer_loss + classifier_loss

        elif 'IWDAN' in config['method']:

            classifier_loss = torch.mean(
                nn.CrossEntropyLoss(weight=class_weights, reduction='none')
                (outputs_source, label_source) * weights) / class_num

            transfer_loss = loss.IWDAN(features, ad_net, weights)
            total_loss = loss_params["trade_off"] * \
                transfer_loss + classifier_loss

        elif config['method'] == 'NANN':
            classifier_loss = nn.CrossEntropyLoss()(outputs_source,
                                                    label_source)
            total_loss = classifier_loss
        else:
            raise ValueError('Method cannot be recognized.')

        total_loss.backward()
        optimizer.step()

        transfer_loss_value = 0 if config[
            'method'] == 'NANN' else transfer_loss.item()
        classifier_loss_value = classifier_loss.item()
        total_loss_value = transfer_loss_value + classifier_loss_value

        if ('IW' in config['method']
            ) and i % (config["dataset_mult_iw"] * len_train_source
                       ) == config["dataset_mult_iw"] * len_train_source - 1:

            pseudo_target_label /= train_bs * \
                len_train_source * config["dataset_mult_iw"]
            cov_mat /= train_bs * len_train_source * config["dataset_mult_iw"]
            print(i, np.sum(cov_mat.cpu().detach().numpy()),
                  train_bs * len_train_source)

            # Recompute the importance weight by solving a QP.
            base_network.im_weights_update(
                source_label_distribution,
                pseudo_target_label.cpu().detach().numpy(),
                cov_mat.cpu().detach().numpy(), config["device"])
            current_weights = [
                round(x, 4)
                for x in base_network.im_weights.data.cpu().numpy().flatten()
            ]
            write_list(config["out_wei_file"], [
                np.linalg.norm(current_weights -
                               true_weights.cpu().numpy().flatten())
            ] + current_weights)
            print(
                np.linalg.norm(current_weights -
                               true_weights.cpu().numpy().flatten()),
                current_weights)

            cov_mat[:] = 0.0
            pseudo_target_label[:] = 0.0

    return best_acc
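A minimal sketch of what base_network.im_weights_update is assumed to do above: solve a small QP that matches the accumulated covariance/confusion matrix against the pseudo target label distribution, i.e. min_w ||cov_mat @ w - pseudo_target||^2 with w >= 0. Non-negative least squares gives the same solution for this sketch; the names below are illustrative, not the repo's API.

import numpy as np
from scipy.optimize import nnls

def estimate_im_weights(cov_mat: np.ndarray, pseudo_target: np.ndarray) -> np.ndarray:
    """Estimate class importance weights w with cov_mat @ w ~= pseudo_target, w >= 0."""
    w, _ = nnls(cov_mat, pseudo_target)
    # Normalize so the weights average to 1 over classes (a common convention).
    return w * (len(w) / max(w.sum(), 1e-12))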
Example #13
def train(config):
    print("Deep copy of model with margin as 1.0")
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
            shuffle=True, num_workers=4, drop_last=True)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
            shuffle=True, num_workers=4, drop_last=True)

    if prep_config["test_10crop"]:
        for i in range(10):
            dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                transform=prep_dict["test"][i]) for i in range(10)]
            dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                                shuffle=False, num_workers=4) for dset in dsets['test']]
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                shuffle=False, num_workers=4)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["call"](net_config["name"],
                                      **net_config["params"])
    base_network_teacher = net_config["call"](net_config["name"],
                                              **net_config["params_teacher"])
    base_network = base_network.cuda()
    base_network_teacher = copy.deepcopy(base_network).cuda()
    for param in base_network_teacher.parameters():
        param.detach_()
    # base_network_teacher = base_network_teacher.cuda()

    # print("check init: ", torch.equal(base_network.fc.weight, base_network_teacher.fc.weight))

    base_network.layer1[-1].relu = nn.ReLU()
    base_network.layer2[-1].relu = nn.ReLU()
    base_network.layer3[-1].relu = nn.ReLU()
    base_network.layer4[-1].relu = nn.ReLU()

    base_network_teacher.layer1[-1].relu = nn.ReLU()
    base_network_teacher.layer2[-1].relu = nn.ReLU()
    base_network_teacher.layer3[-1].relu = nn.ReLU()
    base_network_teacher.layer4[-1].relu = nn.ReLU()

    # print(base_network)

    for n, m in base_network.named_modules():
        if n in ('layer1.2.bn3', 'layer2.3.bn3', 'layer3.5.bn3', 'layer4.2.bn3'):
            m.register_forward_hook(get_activation_student(n))
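    # `get_activation_student` is defined elsewhere in this repo; a typical
    # forward hook that fills the global `activation_student` dict would be:
    #
    #     def get_activation_student(name):
    #         def hook(module, inputs, output):
    #             activation_student[name] = output
    #         return hook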

    if config["loss"]["random"]:
        random_layer = network.RandomLayer(
            [base_network.output_num(), class_num],
            config["loss"]["random_dim"])
        ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(
            base_network.output_num() * class_num, 1024)
    if config["loss"]["random"]:
        random_layer.cuda()
    ad_net = ad_net.cuda()
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()
    Hloss = loss.Entropy()
    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
        base_network = nn.DataParallel(base_network,
                                       device_ids=[int(i) for i in gpus])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    temperature = config["temperature"]

    for i in trange(config["num_iterations"], leave=False):
        global activation_student
        if i % config["test_interval"] == config["test_interval"] - 1:
            base_network.eval()
            base_network_teacher.eval()
            temp_acc, temp_acc_teacher = image_classification_test(dset_loaders, \
                base_network, base_network_teacher, test_10crop=prep_config["test_10crop"])
            temp_model = nn.Sequential(base_network_teacher)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            log_str1 = "precision: {:.5f}".format(temp_acc_teacher)
            config["out_file"].write(log_str + "\t" + log_str1 + "\t" +
                                     str(classifier_loss.item()) + "\t" +
                                     str(dann_loss.item()) + "\t" +
                                     str(ent_loss.item()) + "\t" + "\n")
            config["out_file"].flush()
            print("ent Loss: ", ent_loss.item())
            print("Dann loss: ", dann_loss.item())
            print("Classification Loss: ", classifier_loss.item())
            print(log_str)
            print(log_str1)
        # if i % config["snapshot_interval"] == 0:
        #     torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \
        #         "iter_{:05d}_model.pth.tar".format(i)))

        loss_params = config["loss"]
        ## train one iter
        base_network.train(True)
        base_network_teacher.train(True)

        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])

        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)

        inputs_source1, inputs_source2, inputs_target1, inputs_target2, labels_source = utils.get_copies(
            inputs_source, inputs_target, labels_source)

        margin = 1
        loss_alter = 0

        #### For source data

        features_source, outputs_source = base_network(inputs_source1)
        # features_source2, outputs_source2 = base_network(inputs_source2)

        feature1 = base_network_teacher.features1(inputs_source2)
        feature2 = base_network_teacher.features2(feature1)
        feature3 = base_network_teacher.features3(feature2)
        feature4 = base_network_teacher.features4(feature3)
        feature4_avg = base_network_teacher.avgpool(feature4)
        feature4_res = feature4_avg.view(feature4_avg.size(0), -1)
        features_source2 = base_network_teacher.bottleneck(feature4_res)
        outputs_source2 = base_network_teacher.fc(features_source2)

        loss_alter += loss.decision_boundary_transfer(
            activation_student['layer1.2.bn3'], feature1.detach(), margin) / (
                train_bs * activation_student['layer1.2.bn3'].size(1) * 8)
        loss_alter += loss.decision_boundary_transfer(
            activation_student['layer2.3.bn3'], feature2.detach(), margin) / (
                train_bs * activation_student['layer2.3.bn3'].size(1) * 4)
        loss_alter += loss.decision_boundary_transfer(
            activation_student['layer3.5.bn3'], feature3.detach(), margin) / (
                train_bs * activation_student['layer3.5.bn3'].size(1) * 2)
        loss_alter += loss.decision_boundary_transfer(
            activation_student['layer4.2.bn3'], feature4.detach(),
            margin) / (train_bs * activation_student['layer4.2.bn3'].size(1))
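        # loss.decision_boundary_transfer is assumed to be the activation-boundary
        # distillation margin loss (Heo et al., 2019): it penalizes student
        # pre-activations whose sign disagrees with the teacher's by at least
        # `margin`; the divisors normalize by batch size, channel count, and the
        # relative spatial size of each stage.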

        ## For Target data
        ramp = utils.sigmoid_rampup(i, 100004)
        ramp_confidence = utils.sigmoid_rampup(5 * i, 100004)
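        # utils.sigmoid_rampup is assumed to be the standard consistency-training
        # ramp, exp(-5 * (1 - t)^2) with t = min(step / length, 1), so `ramp`
        # rises smoothly from ~0 to 1; `ramp_confidence` saturates five times faster.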

        features_target, outputs_target = base_network(inputs_target1)
        sample_selection_indices = get_confident_idx.confident_samples(
            base_network, inputs_target1, ramp_confidence, class_num, train_bs)

        confident_targets = utils.subsample(outputs_target,
                                            sample_selection_indices)

        feature1_teacher = base_network_teacher.features1(inputs_target2)
        feature2_teacher = base_network_teacher.features2(feature1_teacher)
        feature3_teacher = base_network_teacher.features3(feature2_teacher)
        feature4_teacher = base_network_teacher.features4(feature3_teacher)
        feature4_teacher_avg = base_network_teacher.avgpool(feature4_teacher)
        feature4_teacher_res = feature4_teacher_avg.view(
            feature4_teacher_avg.size(0), -1)
        features_target2 = base_network_teacher.bottleneck(
            feature4_teacher_res)
        outputs_target2 = base_network_teacher.fc(features_target2)

        loss_alter += loss.decision_boundary_transfer(
            activation_student['layer1.2.bn3'], feature1_teacher.detach(),
            margin) / (train_bs * activation_student['layer1.2.bn3'].size(1) *
                       8)
        loss_alter += loss.decision_boundary_transfer(
            activation_student['layer2.3.bn3'], feature2_teacher.detach(),
            margin) / (train_bs * activation_student['layer2.3.bn3'].size(1) *
                       4)
        loss_alter += loss.decision_boundary_transfer(
            activation_student['layer3.5.bn3'], feature3_teacher.detach(),
            margin) / (train_bs * activation_student['layer3.5.bn3'].size(1) *
                       2)
        loss_alter += loss.decision_boundary_transfer(
            activation_student['layer4.2.bn3'], feature4_teacher.detach(),
            margin) / (train_bs * activation_student['layer4.2.bn3'].size(1))

        loss_alter = loss_alter / 1000  ## maybe multiply by 4 in later tests
        loss_alter = loss_alter.unsqueeze(0).unsqueeze(1)

        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out_src = nn.Softmax(dim=1)(outputs_source)
        softmax_out_tar = nn.Softmax(dim=1)(outputs_target)
        softmax_out = nn.Softmax(dim=1)(outputs)

        features_teacher = torch.cat((features_source2, features_target2),
                                     dim=0)
        outputs_teacher = torch.cat((outputs_source2, outputs_target2), dim=0)
        softmax_out_src_teacher = nn.Softmax(dim=1)(outputs_source2)
        softmax_out_tar_teacher = nn.Softmax(dim=1)(outputs_target2)
        softmax_out_teacher = nn.Softmax(dim=1)(outputs_teacher)

        if config['method'] == 'DANN+E':
            ent_loss = Hloss(confident_targets)
            dann_loss = loss.DANN(features, ad_net)
        elif config['method'] == 'DANN':
            dann_loss = loss.DANN(features, ad_net)
            ent_loss = torch.zeros_like(dann_loss)  # no entropy term for plain DANN
        else:
            raise ValueError('Method cannot be recognized.')
        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        # loss_KD = -(F.softmax(outputs_teacher/ temperature, 1).detach() *
        # 	        (F.log_softmax(outputs/temperature, 1) - F.log_softmax(outputs_teacher/temperature, 1).detach())).sum() / train_bs
        # print(loss_KD)
        # total_loss =  loss_alter #+ (config["ent_loss"] * ent_loss)

        total_loss = dann_loss + classifier_loss + (
            ramp * ent_loss)  #+ (config["ent_loss"] * ent_loss)
        total_loss.backward(retain_graph=True)
        optimizer.step()
        loss.update_ema_variables(base_network, base_network_teacher,
                                  config["teacher_alpha"], i)
    torch.save(best_model, osp.join(config["output_path"],
                                    "best_model.pth.tar"))
    return best_acc
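loss.update_ema_variables above is assumed to follow the standard mean-teacher EMA rule (Tarvainen & Valpola, 2017); a minimal sketch:

import torch

def update_ema_variables(student, teacher, alpha, global_step):
    # Ramp alpha up from 0 so the teacher tracks the student closely early on.
    alpha = min(1 - 1 / (global_step + 1), alpha)
    with torch.no_grad():
        for t_param, s_param in zip(teacher.parameters(), student.parameters()):
            t_param.mul_(alpha).add_(s_param, alpha=1 - alpha)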
Example #14
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='CDAN USPS MNIST')
    parser.add_argument('--method', type=str, default='CDAN-E', choices=['CDAN', 'CDAN-E', 'DANN'])
    parser.add_argument('--task', default='MNIST2USPS', help='task to perform')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test_batch_size', type=int, default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--gpu_id', type=str, default='0',
                        help='cuda device id')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log_interval', type=int, default=10,
                        help='how many batches to wait before logging training status')
    parser.add_argument('--random', action='store_true',
                        help='whether to use the random projection layer')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id


    source_list = '/data/mnist/mnist_train.txt'
    target_list = '/data/usps/usps_train.txt'
    test_list = '/data/usps/usps_test.txt'
    start_epoch = 1
    decay_epoch = 5


    train_loader = torch.utils.data.DataLoader(
        ImageList(open(source_list).readlines(), transform=transforms.Compose([
                           transforms.Resize((28,28)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,))
                       ]), mode='L'),
        batch_size=args.batch_size, shuffle=True, num_workers=1, drop_last=True)
    train_loader1 = torch.utils.data.DataLoader(
        ImageList(open(target_list).readlines(), transform=transforms.Compose([
                           transforms.Resize((28,28)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,))
                       ]), mode='L'),
        batch_size=args.batch_size, shuffle=True, num_workers=1, drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        ImageList(open(test_list).readlines(), transform=transforms.Compose([
                           transforms.Resize((28,28)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,))
                       ]), mode='L'),
        batch_size=args.test_batch_size, shuffle=True, num_workers=1)

    model = network.LeNet()
    # model = model.cuda()
    class_num = 10

    if args.random:
        random_layer = network.RandomLayer([model.output_num(), class_num], 500)
        ad_net = network.AdversarialNetwork(500, 500)
        # random_layer.cuda()
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(model.output_num() * class_num, 500)
    # ad_net = ad_net.cuda()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=0.0005, momentum=0.9)
    optimizer_ad = optim.SGD(ad_net.parameters(), lr=args.lr, weight_decay=0.0005, momentum=0.9)

    for epoch in range(1, args.epochs + 1):
        if epoch % decay_epoch == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.5
        train(args, model, ad_net, random_layer, train_loader, train_loader1, optimizer, optimizer_ad, epoch, start_epoch, args.method)
        test(args, model, test_loader)
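Why the discriminator input above is model.output_num() * class_num: CDAN conditions the domain discriminator on the outer product of the features and the classifier's softmax predictions. A minimal sketch of that multilinear map (illustrative; the repo computes it inside loss.CDAN):

import torch

def multilinear_map(features: torch.Tensor, softmax_out: torch.Tensor) -> torch.Tensor:
    # (B, d) and (B, K) -> (B, d*K): condition each feature on the predicted class.
    op = torch.bmm(softmax_out.unsqueeze(2), features.unsqueeze(1))
    return op.view(features.size(0), -1)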
Example #15
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='CDAN USPS MNIST')
    parser.add_argument('--method',
                        type=str,
                        default='CDAN-E',
                        choices=['CDAN', 'CDAN-E', 'DANN'])
    parser.add_argument('--task',
                        default='MNIST2USPS',
                        help='USPS2MNIST or MNIST2USPS')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--gpu_id',
                        default='0',
                        type=str,
                        help='cuda device id')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        help='how many batches to wait before logging training status')
    parser.add_argument('--random',
                        action='store_true',
                        help='whether to use the random projection layer')
    parser.add_argument('--mdd_weight', type=float, default=0.05)
    parser.add_argument('--entropic_weight', type=float, default=0)
    parser.add_argument("--use_seed", type=bool, default=True)
    args = parser.parse_args()
    import random
    if (args.use_seed):
        torch.manual_seed(args.seed)

        np.random.seed(args.seed)
        random.seed(args.seed)
        torch.backends.cudnn.deterministic = True
    import os.path as osp
    import datetime
    config = {}
    config["output_path"] = "snapshot/" + args.task
    config['seed'] = args.seed
    config["torch_seed"] = torch.initial_seed()
    config["torch_cuda_seed"] = torch.cuda.initial_seed()

    config["mdd_weight"] = args.mdd_weight
    config["entropic_weight"] = args.entropic_weight
    if not osp.exists(config["output_path"]):
        os.system('mkdir -p ' + config["output_path"])
    config["out_file"] = open(
        osp.join(
            config["output_path"],
            "log_{}_{}.txt".format(args.task,
                                   str(datetime.datetime.utcnow()))), "w")

    torch.manual_seed(args.seed)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    if args.task == 'USPS2MNIST':
        source_list = 'data/usps2mnist/usps_train.txt'
        target_list = 'data/usps2mnist/mnist_train.txt'
        test_list = 'data/usps2mnist/mnist_test.txt'
        start_epoch = 1
        decay_epoch = 6
    elif args.task == 'MNIST2USPS':
        source_list = 'data/usps2mnist/mnist_train.txt'
        target_list = 'data/usps2mnist/usps_train.txt'
        test_list = 'data/usps2mnist/usps_test.txt'
        start_epoch = 1
        decay_epoch = 5
    else:
        raise Exception('task cannot be recognized!')

    train_loader = torch.utils.data.DataLoader(ImageList(
        open(source_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='L'),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1,
                                               drop_last=True)
    train_loader1 = torch.utils.data.DataLoader(ImageList(
        open(target_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='L'),
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=1,
                                                drop_last=True)
    test_loader = torch.utils.data.DataLoader(ImageList(
        open(test_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='L'),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              num_workers=1)

    model = network.LeNet()
    model = model.cuda()
    class_num = 10

    if args.random:
        random_layer = network.RandomLayer([model.output_num(), class_num],
                                           500)
        ad_net = network.AdversarialNetwork(500, 500)
        random_layer.cuda()
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(model.output_num() * class_num,
                                            500)
    ad_net = ad_net.cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          weight_decay=0.0005,
                          momentum=0.9)
    optimizer_ad = optim.SGD(ad_net.parameters(),
                             lr=args.lr,
                             weight_decay=0.0005,
                             momentum=0.9)
    config["out_file"].write(str(config) + "\n")
    config["out_file"].flush()
    for epoch in range(1, args.epochs + 1):
        if epoch % decay_epoch == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.5
        train(args, model, ad_net, random_layer, train_loader, train_loader1,
              optimizer, optimizer_ad, epoch, start_epoch, args.method)
        test(args, epoch, config, model, test_loader)
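When args.random is set, network.RandomLayer replaces the full feature-prediction outer product with a randomized multilinear map so the discriminator input stays at a fixed dimension (500 here). A sketch under the CDAN-paper definition (illustrative, not the repo's class):

import torch

class RandomizedMultilinear(torch.nn.Module):
    """Fixed random projections R_f, R_g; output (R_f f) * (R_g g) / sqrt(dim)."""
    def __init__(self, feat_dim: int, class_num: int, out_dim: int = 500):
        super().__init__()
        self.register_buffer('Rf', torch.randn(feat_dim, out_dim))
        self.register_buffer('Rg', torch.randn(class_num, out_dim))
        self.scale = out_dim ** 0.5

    def forward(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        return (f @ self.Rf) * (g @ self.Rg) / self.scale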
Example #16
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=True)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=True)

    if prep_config["test_10crop"]:
        for i in range(10):
            dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                       transform=prep_dict["test"][i]) for i in range(10)]
            dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                                               shuffle=False, num_workers=0) for dset in dsets['test']]
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                  transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                          shuffle=False, num_workers=0)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    # base_network = base_network.cuda()

    ## add discriminators D_s, D_t and generators G_s2t, G_t2s

    z_dimension = 256
    D_s = network.models["Discriminator"]()
    # D_s = D_s.cuda()
    G_s2t = network.models["Generator"](z_dimension, 1024)
    # G_s2t = G_s2t.cuda()

    D_t = network.models["Discriminator"]()
    # D_t = D_t.cuda()
    G_t2s = network.models["Generator"](z_dimension, 1024)
    # G_t2s = G_t2s.cuda()

    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()
    criterion_Sem = torch.nn.L1Loss()

    optimizer_G = torch.optim.Adam(itertools.chain(G_s2t.parameters(), G_t2s.parameters()), lr=0.0003)
    optimizer_D_s = torch.optim.Adam(D_s.parameters(), lr=0.0003)
    optimizer_D_t = torch.optim.Adam(D_t.parameters(), lr=0.0003)

    fake_S_buffer = ReplayBuffer()
    fake_T_buffer = ReplayBuffer()

    classifier_optimizer = torch.optim.Adam(base_network.parameters(), lr=0.0003)
    ## add an auxiliary classifier
    classifier1 = net.Net(256,class_num)
    # classifier1 = classifier1.cuda()
    classifier1_optim = optim.Adam(classifier1.parameters(), lr=0.0003)

    ## add additional network for some methods
    if config["loss"]["random"]:
        random_layer = network.RandomLayer([base_network.output_num(), class_num], config["loss"]["random_dim"])
        ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(base_network.output_num() * class_num, 1024)
    if config["loss"]["random"]:
        random_layer.cuda()
    # ad_net = ad_net.cuda()
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                                         **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
        base_network = nn.DataParallel(base_network, device_ids=[int(i) for i in gpus])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == config["test_interval"] - 1:
            base_network.train(False)
            temp_acc = image_classification_test(dset_loaders, \
                                                 base_network, test_10crop=prep_config["test_10crop"])
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model

                now = datetime.datetime.now()
                d = str(now.month) + '-' + str(now.day) + ' ' + str(now.hour) + ':' + str(now.minute) + ":" + str(
                    now.second)
                torch.save(best_model, osp.join(config["output_path"],
                                                "{}_to_{}_best_model_acc-{}_{}.pth.tar".format(args.source, args.target,
                                                                                               best_acc, d)))
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()

            print(log_str)
        if i % config["snapshot_interval"] == 0:
            torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \
                                                             "{}_to_{}_iter_{:05d}_model_{}.pth.tar".format(args.source,
                                                                                                            args.target,
                                                                                                            i, str(
                                                                     datetime.datetime.utcnow()))))
        print("it_train: {:05d} / {:05d} start".format(i, config["num_iterations"]))
        loss_params = config["loss"]
        ## train one iter
        classifier1.train(True)
        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()


        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        # inputs_source, inputs_target, labels_source = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()

        # extract features
        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)

        outputs_source1 = classifier1(features_source.detach())
        outputs_target1 = classifier1(features_target.detach())
        outputs1 = torch.cat((outputs_source1,outputs_target1),dim=0)
        softmax_out1 = nn.Softmax(dim=1)(outputs1)

        softmax_out = (1-args.cla_plus_weight)*softmax_out + args.cla_plus_weight*softmax_out1

        if config['method'] == 'CDAN+E':
            entropy = loss.Entropy(softmax_out)
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy, network.calc_coeff(i), random_layer)
        elif config['method'] == 'CDAN':
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, None, None, random_layer)
        elif config['method'] == 'DANN':
            transfer_loss = loss.DANN(features, ad_net)
        else:
            raise ValueError('Method cannot be recognized.')
        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)

        # Cycle
        num_feature = features_source.size(0)
        # ================= train discriminator T
        real_label = torch.ones(num_feature)
        # real_label = torch.ones(num_feature).cuda()
        fake_label = torch.zeros(num_feature)
        # fake_label = torch.zeros(num_feature).cuda()

        # train the generators
        optimizer_G.zero_grad()

        # Identity loss
        same_t = G_s2t(features_target.detach())
        loss_identity_t = criterion_identity(same_t, features_target)

        same_s = G_t2s(features_source.detach())
        loss_identity_s = criterion_identity(same_s, features_source)

        # GAN loss: the generators try to make the discriminators predict "real"
        fake_t = G_s2t(features_source.detach())
        pred_fake = D_t(fake_t)
        loss_G_s2t = criterion_GAN(pred_fake, real_label)

        fake_s = G_t2s(features_target.detach())
        pred_fake = D_s(fake_s)
        loss_G_t2s = criterion_GAN(pred_fake, real_label)

        # cycle loss
        recovered_s = G_t2s(fake_t)
        loss_cycle_sts = criterion_cycle(recovered_s, features_source)

        recovered_t = G_s2t(fake_s)
        loss_cycle_tst = criterion_cycle(recovered_t, features_target)

        # sem loss
        pred_recovered_s = base_network.fc(recovered_s)
        pred_fake_t = base_network.fc(fake_t)
        loss_sem_t2s = criterion_Sem(pred_recovered_s, pred_fake_t)

        pred_recovered_t = base_network.fc(recovered_t)
        pred_fake_s = base_network.fc(fake_s)
        loss_sem_s2t = criterion_Sem(pred_recovered_t, pred_fake_s)

        loss_cycle = loss_cycle_tst + loss_cycle_sts
        weights = args.weight_in_lossG.split(',')
        loss_G = float(weights[0]) * (loss_identity_s + loss_identity_t) + \
                 float(weights[1]) * (loss_G_s2t + loss_G_t2s) + \
                 float(weights[2]) * loss_cycle + \
                 float(weights[3]) * (loss_sem_s2t + loss_sem_t2s)

        # train the softmax classifier on the generated target features
        outputs_fake = classifier1(fake_t.detach())
        # optimize the classifier
        classifier_loss1 = nn.CrossEntropyLoss()(outputs_fake, labels_source)
        classifier1_optim.zero_grad()
        classifier_loss1.backward()
        classifier1_optim.step()

        total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss + args.cyc_loss_weight*loss_G
        total_loss.backward()
        optimizer.step()
        optimizer_G.step()

        ###### Discriminator S ######
        optimizer_D_s.zero_grad()

        # Real loss
        pred_real = D_s(features_source.detach())
        loss_D_real = criterion_GAN(pred_real, real_label)

        # Fake loss
        fake_s = fake_S_buffer.push_and_pop(fake_s)
        pred_fake = D_s(fake_s.detach())
        loss_D_fake = criterion_GAN(pred_fake, fake_label)

        # Total loss
        loss_D_s = loss_D_real + loss_D_fake
        loss_D_s.backward()

        optimizer_D_s.step()
        ###################################

        ###### Discriminator t ######
        optimizer_D_t.zero_grad()

        # Real loss
        pred_real = D_t(features_target.detach())
        loss_D_real = criterion_GAN(pred_real, real_label)

        # Fake loss
        fake_t = fake_T_buffer.push_and_pop(fake_t)
        pred_fake = D_t(fake_t.detach())
        loss_D_fake = criterion_GAN(pred_fake, fake_label)

        # Total loss
        loss_D_t = loss_D_real + loss_D_fake
        loss_D_t.backward()
        optimizer_D_t.step()
        print("it_train: {:05d} / {:05d} over".format(i, config["num_iterations"]))
    now = datetime.datetime.now()
    d = str(now.month)+'-'+str(now.day)+' '+str(now.hour)+':'+str(now.minute)+":"+str(now.second)
    torch.save(best_model, osp.join(config["output_path"],
                                    "{}_to_{}_best_model_acc-{}_{}.pth.tar".format(args.source, args.target,
                                                                            best_acc,d)))
    return best_acc
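fake_S_buffer / fake_T_buffer above use the CycleGAN image-pool trick: the discriminators are updated on a history of generated samples rather than only the latest batch, which stabilizes adversarial training. A minimal sketch of such a buffer (illustrative; the repo ships its own ReplayBuffer):

import random
import torch

class ReplayBuffer:
    def __init__(self, max_size: int = 50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, batch: torch.Tensor) -> torch.Tensor:
        out = []
        for item in batch.detach():
            item = item.unsqueeze(0)
            if len(self.data) < self.max_size:
                self.data.append(item)
                out.append(item)
            elif random.random() > 0.5:
                # Return an old sample and replace it with the new one.
                idx = random.randrange(self.max_size)
                out.append(self.data[idx].clone())
                self.data[idx] = item
            else:
                out.append(item)
        return torch.cat(out)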
Example #17
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=True)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=True)

    if prep_config["test_10crop"]:
        for i in range(10):
            dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                       transform=prep_dict["test"][i]) for i in range(10)]
            dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                                               shuffle=False, num_workers=0) for dset in dsets['test']]
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                  transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                          shuffle=False, num_workers=0)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.cuda()

    ## add additional network for some methods
    if config["loss"]["random"]:
        random_layer = network.RandomLayer(
            [base_network.output_num(), class_num],
            config["loss"]["random_dim"])
        ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(
            base_network.output_num() * class_num, 1024)
    if config["loss"]["random"]:
        random_layer.cuda()
    ad_net = ad_net.cuda()
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                                         **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
        base_network = nn.DataParallel(base_network,
                                       device_ids=[int(i) for i in gpus])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    best_acc = 0.0
    best_model = nn.Sequential(base_network)
    each_log = ""
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == config["test_interval"] - 1:

            base_network.train(False)
            temp_acc = image_classification_test(dset_loaders, \
                                                 base_network, test_10crop=prep_config["test_10crop"])
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model
            log_str = "iter: {:05d}, precision: {:.5f}, transfer_loss:{:.4f}, classifier_loss:{:.4f}, total_loss:{:.4f}" \
                .format(i, temp_acc, transfer_loss.item(), classifier_loss.item(), total_loss.item())
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()
            print(log_str)

            config["out_file"].write(each_log)
            config["out_file"].flush()
            each_log = ""
        loss_params = config["loss"]
        ## train one iter
        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        inputs_source, inputs_target, labels_source = inputs_source.cuda(
        ), inputs_target.cuda(), labels_source.cuda()
        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)
        labels_target_fake = torch.max(nn.Softmax(dim=1)(outputs_target), 1)[1]
        labels = torch.cat((labels_source, labels_target_fake))
        entropy = loss.Entropy(softmax_out)
        transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy,
                                  network.calc_coeff(i), random_layer)

        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        mdd_loss = loss.mdd_loss(features=features,
                                 labels=labels,
                                 left_weight=args.left_weight,
                                 right_weight=args.right_weight)
        max_entropy_loss = loss.EntropicConfusion(features)
        total_loss = loss_params["trade_off"] * transfer_loss \
                     + args.cls_weight * classifier_loss \
                     + args.mdd_weight * mdd_loss \
                     + args.entropic_weight * max_entropy_loss
        total_loss.backward()
        optimizer.step()
        log_str = "iter: {:05d},transfer_loss:{:.4f}, classifier_loss:{:.4f}, mdd_loss:{:4f}," \
                  "max_entropy_loss:{:.4f},total_loss:{:.4f}" \
            .format(i, transfer_loss.item(), classifier_loss.item(), mdd_loss.item(),
                    max_entropy_loss.item(), total_loss.item())
        each_log += log_str + "\n"

    torch.save(
        best_model, config['model_output_path'] + "{}_{}_p-{}_e-{}".format(
            config['log_name'], str(best_acc), str(config["mdd_weight"]),
            str(config["entropic_weight"])))
    return best_acc
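loss.Entropy here computes the per-sample prediction entropy that the CDAN+E transfer loss uses to down-weight uncertain examples in the discriminator. A one-line sketch of the functional form used above (illustrative):

import torch

def entropy(softmax_out: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Per-sample entropy H(p) = -sum_k p_k log p_k, shape (B,).
    return -(softmax_out * torch.log(softmax_out + eps)).sum(dim=1)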
Example #18
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
            shuffle=True, num_workers=4, drop_last=True)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
            shuffle=True, num_workers=4, drop_last=True)

    #     if prep_config["test_10crop"]:
    #         for i in range(10):
    #             dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
    #                                 transform=prep_dict["test"][i]) for i in range(10)]
    #             dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
    #                                 shuffle=False, num_workers=4) for dset in dsets['test']]
    #     else:
    #         dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
    #                                 transform=prep_dict["test"])
    #         dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
    #                                 shuffle=False, num_workers=4)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.cuda()

    ## add additional network for some methods
    if config["loss"]["random"]:
        random_layer = network.RandomLayer(
            [base_network.output_num(), class_num],
            config["loss"]["random_dim"])
        ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(
            base_network.output_num() * class_num, 1024)
    if config["loss"]["random"]:
        random_layer.cuda()
    ad_net = ad_net.cuda()
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
        base_network = nn.DataParallel(base_network,
                                       device_ids=[int(i) for i in gpus])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    best_model = nn.Sequential(base_network)  # fallback: the test loop below is commented out
    for i in range(config["num_iterations"]):
        #         if i % config["test_interval"] == config["test_interval"] - 1:
        #             base_network.train(False)
        #             temp_acc = image_classification_test(dset_loaders, \
        #                 base_network, test_10crop=prep_config["test_10crop"])
        #             temp_model = nn.Sequential(base_network)
        #             if temp_acc > best_acc:
        #                 best_acc = temp_acc
        #                 best_model = temp_model
        #             log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
        #             config["out_file"].write(log_str+"\n")
        #             config["out_file"].flush()
        #             print(log_str)
        if i % config["snapshot_interval"] == 0:
            torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \
                "iter_{:05d}_model.pth.tar".format(i)))

        loss_params = config["loss"]
        ## train one iter
        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        inputs_source, inputs_target, labels_source = inputs_source.cuda(
        ), inputs_target.cuda(), labels_source.cuda()
        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)
        if config['method'] == 'CDAN+E':
            entropy = loss.Entropy(softmax_out)
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy,
                                      network.calc_coeff(i), random_layer)
        elif config['method'] == 'CDAN':
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, None,
                                      None, random_layer)
        elif config['method'] == 'DANN':
            transfer_loss = loss.DANN(features, ad_net)
        else:
            raise ValueError('Method cannot be recognized.')
        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        if i % 10 == 0:
            print('iter: ', i, 'classifier_loss: ', classifier_loss.data,
                  'transfer_loss: ', transfer_loss.data)
        total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss
        total_loss.backward()
        optimizer.step()
    torch.save(best_model, osp.join(config["output_path"],
                                    "best_model.pth.tar"))
    return best_acc
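network.calc_coeff, used in several examples to anneal the adversarial term, is assumed to be the usual DANN/CDAN schedule 2*(high-low)/(1+exp(-alpha*p)) - (high-low) + low with p = iter/max_iter; a sketch:

import numpy as np

def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
    # Ramps smoothly from `low` at iteration 0 toward `high` as training proceeds.
    p = iter_num / max_iter
    return float(2.0 * (high - low) / (1.0 + np.exp(-alpha * p)) - (high - low) + low)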
Example #19
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='CDAN USPS MNIST')
    parser.add_argument('--method', type=str, default='DANN')
    parser.add_argument('--task', default='MNIST2USPS', help='task to perform')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=20,
                        metavar='N',
                        help='number of epochs to train (default: 20)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--gpu_id',
                        type=str,
                        default='0',
                        help='cuda device id')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        help='how many batches to wait before logging training status')
    parser.add_argument('--random',
                        action='store_true',
                        help='whether to use the random projection layer')
    parser.add_argument('--weight',
                        type=int,
                        default=0,
                        help="whether to use weights during transfer")
    parser.add_argument('--temp_max',
                        type=float,
                        default=5.,
                        help="weight relaxation parameter")
    parser.add_argument('--alpha',
                        type=float,
                        default=5.,
                        help="weight relaxation parameter")

    args = parser.parse_args()

    torch.manual_seed(args.seed)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    if args.task == 'USPS2MNIST':
        source_list = '../data/usps2mnist/usps_train_shift_20.txt'
        target_list = '../data/usps2mnist/mnist_train.txt'
        test_list = '../data/usps2mnist/mnist_test.txt'
        start_epoch = 1
        decay_epoch = 6
    elif args.task == 'MNIST2USPS':
        source_list = '../data/usps2mnist/mnist_train_shift_20.txt'
        target_list = '../data/usps2mnist/usps_train.txt'
        test_list = '../data/usps2mnist/usps_test.txt'
        start_epoch = 1
        decay_epoch = 5
    else:
        raise Exception('task cannot be recognized!')

    train_loader = torch.utils.data.DataLoader(ImageList(
        open(source_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='L'),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1,
                                               drop_last=True)
    train_loader1 = torch.utils.data.DataLoader(ImageList(
        open(target_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='L'),
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=1,
                                                drop_last=True)
    test_loader = torch.utils.data.DataLoader(ImageList(
        open(test_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='L'),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              num_workers=1)

    model = network.LeNet()
    model = model.cuda()
    class_num = 10

    if args.method == 'DANN':
        random_layer = None
        ad_net = network.AdversarialNetwork(model.output_num(), 500)

    elif args.method == 'Y_DAN':
        random_layer = None
        ad_net = network.AdversarialNetwork(model.output_num(),
                                            500,
                                            output_dim=class_num)

    elif args.method == 'CDAN':
        if args.random:
            random_layer = network.RandomLayer([model.output_num(), class_num],
                                               500)
            random_layer.cuda()
        else:
            random_layer = None
        ad_net = network.AdversarialNetwork(500 * 10, 500)
    else:
        raise ValueError('Method cannot be recognized.')

    ad_w_net = network.AdversarialNetwork(model.output_num(),
                                          500,
                                          output_dim=1)

    ad_net = ad_net.cuda()
    ad_w_net = ad_w_net.cuda()

    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          weight_decay=0.0005,
                          momentum=0.9)
    optimizer_ad = optim.SGD(ad_net.parameters(),
                             lr=args.lr,
                             weight_decay=0.0005,
                             momentum=0.9)
    optimizer_ad_w = optim.SGD(ad_w_net.parameters(),
                               lr=args.lr,
                               weight_decay=0.0005,
                               momentum=0.9)

    for epoch in range(1, args.epochs + 1):
        print("epoch", epoch)
        if epoch % decay_epoch == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.5
        train(args,
              epoch,
              model,
              ad_net,
              ad_w_net,
              train_loader,
              train_loader1,
              optimizer,
              optimizer_ad,
              optimizer_ad_w,
              epoch,
              start_epoch,
              args.method,
              random_layer=random_layer)
        test(args, model, test_loader)
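For reference, loss.DANN as used throughout these examples is assumed to be the standard domain-adversarial loss: the discriminator (which applies a gradient-reversal layer internally in this repo) scores the concatenated source/target features and is trained with binary cross-entropy against domain labels. A minimal sketch:

import torch
import torch.nn as nn

def dann_loss(features: torch.Tensor, ad_net: nn.Module) -> torch.Tensor:
    # By construction above, the first half of the batch is source, the second target.
    domain_pred = ad_net(features)
    half = features.size(0) // 2
    domain_label = torch.cat([torch.ones(half, 1), torch.zeros(half, 1)]).to(features.device)
    return nn.BCELoss()(domain_pred, domain_label)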