Example no. 1
def show_me_the_graphs(config):

	if config["network"] == "DeepMerge":
		config["network"] = {"name":network.DeepMerge, "params":{"class_num":2, "new_cls":True, "use_bottleneck":False, "bottleneck_dim":32*9*9}}
	elif config["network"] == "Res18":
	    config["network"] = {"name":network.ResNetFc, "params":{"class_num":2, "resnet_name": "ResNet18", "use_bottleneck":True, "bottleneck_dim":256, "new_cls":True}}

	use_gpu = torch.cuda.is_available()
	net_config = config["network"]
	base_network = net_config["name"](**net_config["params"])
	ad_net = network.AdversarialNetwork(base_network.output_num())
	#gradient_reverse_layer = network.AdversarialLayer(high_value = config["high"])

	if use_gpu:
		base_network = base_network.cuda()
		ad_net = ad_net.cuda()

	## prepare data
	dsets = {}
	dset_loaders = {}

	pristine_indices = torch.randperm(len(pristine_x))
	pristine_x_train = pristine_x[pristine_indices[:int(np.floor(.7*len(pristine_x)))]]
	pristine_y_train = pristine_y[pristine_indices[:int(np.floor(.7*len(pristine_x)))]]

	noisy_indices = torch.randperm(len(noisy_x))
	noisy_x_train = noisy_x[noisy_indices[:int(np.floor(.7*len(noisy_x)))]]
	noisy_y_train = noisy_y[noisy_indices[:int(np.floor(.7*len(noisy_x)))]]

	dsets["source"] = TensorDataset(pristine_x_train, pristine_y_train)
	dsets["target"] = TensorDataset(noisy_x_train, noisy_y_train)

	dset_loaders["source"] = DataLoader(dsets["source"], batch_size =128, shuffle = True, num_workers = 1)
	dset_loaders["target"] = DataLoader(dsets["target"], batch_size = 128, shuffle = True, num_workers = 1)

	# run a dummy batch through the base network; its features are the input to ad_net
	inputs_source, labels_source = next(iter(dset_loaders["source"]))
	inputs_target, labels_target = next(iter(dset_loaders["target"]))

	if use_gpu:
		source_batch = Variable(inputs_source).cuda()
		target_batch = Variable(inputs_target).cuda()
	else:
		source_batch = Variable(inputs_source)
		target_batch = Variable(inputs_target)

	inputs = torch.cat((source_batch, target_batch), dim=0)
	weight_ad = torch.ones(inputs.size(0))

	features, base_logits = base_network(inputs)
	yhat_ad = ad_net(features.detach())

	input_names_base = ['Galaxy Array']
	output_names_base = ['Merger or Not']

	input_names_adnet = ['Base Network Features']
	output_names_adnet = ['Source or Target Domain']

	torch.onnx.export(base_network, inputs, 'base.onnx', input_names=input_names_base, output_names=output_names_base)
	torch.onnx.export(ad_net, features, 'ad_net.onnx', input_names=input_names_adnet, output_names=output_names_adnet)
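A minimal invocation sketch for the function above, assuming pristine_x, pristine_y, noisy_x and noisy_y exist at module level (the function reads them as globals) and that network, torch and numpy are already imported:

if __name__ == "__main__":
    # "DeepMerge" or "Res18" selects which base network is built and exported
    config = {"network": "DeepMerge"}
    show_me_the_graphs(config)
    # 'base.onnx' and 'ad_net.onnx' can then be opened in any ONNX graph viewer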
Example no. 2
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='CDAN SVHN MNIST')
    parser.add_argument('--method',
                        type=str,
                        default='CDAN-E',
                        choices=['CDAN', 'CDAN-E', 'DANN'])
    parser.add_argument('--task', default='USPS2MNIST', help='task to perform')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.03, metavar='LR')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--gpu_id', type=str, help='cuda device id')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        help='how many batches to wait before logging training status')
    parser.add_argument('--random',
                        type=bool,
                        default=False,
                        help='whether to use the random projection layer')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    source_list = '../data/svhn2mnist/svhn_balanced.txt'
    target_list = '../data/svhn2mnist/mnist_train.txt'
    test_list = '../data/svhn2mnist/mnist_test.txt'

    train_loader = torch.utils.data.DataLoader(ImageList(
        open(source_list).readlines(),
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, ), (0.5, ))]),
        mode='RGB'),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)
    train_loader1 = torch.utils.data.DataLoader(ImageList(
        open(target_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='RGB'),
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=1)
    test_loader = torch.utils.data.DataLoader(ImageList(
        open(test_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='RGB'),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              num_workers=1)

    model = network.DTN()
    model = model.cuda()
    class_num = 10

    if args.random:
        random_layer = network.RandomLayer([model.output_num(), class_num],
                                           500)
        ad_net = network.AdversarialNetwork(500, 500)
        random_layer.cuda()
    else:
        random_layer = None
        print('feature dim, class num:', model.output_num(), class_num)
        ad_net = network.AdversarialNetwork(model.output_num() * class_num,
                                            500)
    ad_net = ad_net.cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          weight_decay=0.0005,
                          momentum=0.9)
    optimizer_ad = optim.SGD(ad_net.parameters(),
                             lr=args.lr,
                             weight_decay=0.0005,
                             momentum=0.9)

    for epoch in range(1, args.epochs + 1):
        if epoch % 3 == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.3
        train(args, model, ad_net, random_layer, train_loader, train_loader1,
              optimizer, optimizer_ad, epoch, 0, args.method)
        acc = test(args, model, test_loader)
        with summary_writer.as_default():
            tf.summary.scalar("acc", acc, step=epoch)
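The summary_writer and tf.summary calls above assume a module-level TensorBoard writer; a minimal sketch of that setup (the log directory name is only illustrative):

import tensorflow as tf
summary_writer = tf.summary.create_file_writer("logs/cdan_svhn2mnist")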
Example no. 3
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='CDAN USPS MNIST')
    parser.add_argument('--method', type=str, default='CDAN-E', choices=['CDAN', 'CDAN-E', 'DANN'])
    parser.add_argument('--task', default='USPS2MNIST', help='task to perform')
    parser.add_argument('--batch_size', type=int, default=256, help='input batch size for training (default: 256)')
    parser.add_argument('--test_batch_size', type=int, default=1000, help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N', help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR', help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M', help='SGD momentum (default: 0.5)')
    parser.add_argument('--gpu_id', type=str,default="0",  help='cuda device id')
    parser.add_argument('--seed', type=int, default=1, metavar='S',  help='random seed (default: 1)')
    parser.add_argument('--log_interval', type=int, default=10, help='how many batches to wait before logging training status')
    parser.add_argument('--random', type=bool, default=False, help='whether to use the random projection layer')
    parser.add_argument('--output_dir',type=str,default="digits/u2m")
    parser.add_argument('--cla_plus_weight',type=float,default=0.3)
    parser.add_argument('--cyc_loss_weight',type=float,default=0.01)
    parser.add_argument('--weight_in_loss_g',type=str,default='1,0.01,0.1,0.1')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    # train config
    import os.path as osp
    import datetime
    config = {}

    config['method'] = args.method
    config["gpu"] = args.gpu_id
    config['cyc_loss_weight'] = args.cyc_loss_weight
    config['cla_plus_weight'] = args.cla_plus_weight
    config['weight_in_loss_g'] = args.weight_in_loss_g
    config["epochs"] = args.epochs
    config["output_for_test"] = True
    config["output_path"] = "snapshot/" + args.output_dir
    if not osp.exists(config["output_path"]):
        os.system('mkdir -p ' + config["output_path"])
    config["out_file"] = open(osp.join(config["output_path"], "log_{}_{}.txt".
                                       format(args.task,str(datetime.datetime.utcnow()))),
                              "w")

    config["out_file"].write(str(config))
    config["out_file"].flush()


    source_list = '/data/usps/usps_train.txt'
    target_list = '/data/mnist/mnist_train.txt'
    test_list = '/data/mnist/mnist_test.txt'
    start_epoch = 1
    decay_epoch = 6


    train_loader = torch.utils.data.DataLoader(
        ImageList(open(source_list).readlines(), transform=transforms.Compose([
                           transforms.Resize((28,28)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,))
                       ]), mode='L'),
        batch_size=args.batch_size, shuffle=True, num_workers=1, drop_last=True)
    train_loader1 = torch.utils.data.DataLoader(
        ImageList(open(target_list).readlines(), transform=transforms.Compose([
                           transforms.Resize((28,28)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,))
                       ]), mode='L'),
        batch_size=args.batch_size, shuffle=True, num_workers=1, drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        ImageList(open(test_list).readlines(), transform=transforms.Compose([
                           transforms.Resize((28,28)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,))
                       ]), mode='L'),
        batch_size=args.test_batch_size, shuffle=True, num_workers=1)

    model = network.LeNet()
    # model = model.cuda()
    class_num = 10

    # add generators G, discriminators D, and an extra classifier
    import itertools
    from utils import ReplayBuffer
    import net
    z_dimension = 500
    D_s = network.models["Discriminator_um"]()
    # D_s = D_s.cuda()
    G_s2t = network.models["Generator_um"](z_dimension, 500)
    # G_s2t = G_s2t.cuda()

    D_t = network.models["Discriminator_um"]()
    # D_t = D_t.cuda()
    G_t2s = network.models["Generator_um"](z_dimension, 500)
    # G_t2s = G_t2s.cuda()

    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()
    criterion_Sem = torch.nn.L1Loss()

    optimizer_G = torch.optim.Adam(itertools.chain(G_s2t.parameters(), G_t2s.parameters()), lr=0.0003)
    optimizer_D_s = torch.optim.Adam(D_s.parameters(), lr=0.0003)
    optimizer_D_t = torch.optim.Adam(D_t.parameters(), lr=0.0003)

    fake_S_buffer = ReplayBuffer()
    fake_T_buffer = ReplayBuffer()

    ## add the extra classifier
    classifier1 = net.Net(500, class_num)
    # classifier1 = classifier1.cuda()
    classifier1_optim = optim.Adam(classifier1.parameters(), lr=0.0003)


    if args.random:
        random_layer = network.RandomLayer([model.output_num(), class_num], 500)
        ad_net = network.AdversarialNetwork(500, 500)
        # random_layer.cuda()
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(model.output_num() * class_num, 500)
    # ad_net = ad_net.cuda()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=0.0005, momentum=0.9)
    optimizer_ad = optim.SGD(ad_net.parameters(), lr=args.lr, weight_decay=0.0005, momentum=0.9)

    for epoch in range(1, args.epochs + 1):
        if epoch % decay_epoch == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.5
        train(args, model, ad_net, random_layer, train_loader, train_loader1, optimizer, optimizer_ad, epoch, start_epoch, args.method,
              D_s, D_t, G_s2t, G_t2s, criterion_Sem, criterion_GAN, criterion_cycle, criterion_identity, optimizer_G,
              optimizer_D_t, optimizer_D_s,
              classifier1, classifier1_optim, fake_S_buffer, fake_T_buffer
              )
        test(args,epoch,config, model, test_loader)
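ReplayBuffer is imported from utils but not shown in this listing; a minimal sketch of the CycleGAN-style buffer it is assumed to implement, where push_and_pop stores generated features and returns a mix of fresh and previously stored ones:

import random
import torch

class ReplayBuffer:
    # Sketch of the assumed utils.ReplayBuffer behaviour (capacity of 50 generated samples).
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, batch):
        out = []
        for element in batch.detach():
            element = element.unsqueeze(0)
            if len(self.data) < self.max_size:
                # buffer not full yet: store and return the fresh sample
                self.data.append(element)
                out.append(element)
            elif random.uniform(0, 1) > 0.5:
                # half of the time, return an older sample and replace it with the fresh one
                idx = random.randint(0, self.max_size - 1)
                out.append(self.data[idx].clone())
                self.data[idx] = element
            else:
                out.append(element)
        return torch.cat(out, dim=0)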
Example no. 4
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=True)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=True)

    if prep_config["test_10crop"]:
        dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                   transform=prep_dict["test"][i]) for i in range(10)]
        dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                                           shuffle=False, num_workers=0) for dset in dsets['test']]
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                  transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                          shuffle=False, num_workers=0)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    # base_network = base_network.cuda()

    ## add discriminators D_s, D_t and generators G_s2t, G_t2s

    z_dimension = 256
    D_s = network.models["Discriminator"]()
    # D_s = D_s.cuda()
    G_s2t = network.models["Generator"](z_dimension, 1024)
    # G_s2t = G_s2t.cuda()

    D_t = network.models["Discriminator"]()
    # D_t = D_t.cuda()
    G_t2s = network.models["Generator"](z_dimension, 1024)
    # G_t2s = G_t2s.cuda()

    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()
    criterion_Sem = torch.nn.L1Loss()

    optimizer_G = torch.optim.Adam(itertools.chain(G_s2t.parameters(), G_t2s.parameters()), lr=0.0003)
    optimizer_D_s = torch.optim.Adam(D_s.parameters(), lr=0.0003)
    optimizer_D_t = torch.optim.Adam(D_t.parameters(), lr=0.0003)

    fake_S_buffer = ReplayBuffer()
    fake_T_buffer = ReplayBuffer()

    classifier_optimizer = torch.optim.Adam(base_network.parameters(), lr=0.0003)
    ## add the extra classifier
    classifier1 = net.Net(256,class_num)
    # classifier1 = classifier1.cuda()
    classifier1_optim = optim.Adam(classifier1.parameters(), lr=0.0003)

    ## add additional network for some methods
    if config["loss"]["random"]:
        random_layer = network.RandomLayer([base_network.output_num(), class_num], config["loss"]["random_dim"])
        ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(base_network.output_num() * class_num, 1024)
    if config["loss"]["random"]:
        random_layer.cuda()
    # ad_net = ad_net.cuda()
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                                         **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
        base_network = nn.DataParallel(base_network, device_ids=[int(i) for i in gpus])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == config["test_interval"] - 1:
            base_network.train(False)
            temp_acc = image_classification_test(dset_loaders, \
                                                 base_network, test_10crop=prep_config["test_10crop"])
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model

                now = datetime.datetime.now()
                d = str(now.month) + '-' + str(now.day) + ' ' + str(now.hour) + ':' + str(now.minute) + ":" + str(
                    now.second)
                torch.save(best_model, osp.join(config["output_path"],
                                                "{}_to_{}_best_model_acc-{}_{}.pth.tar".format(args.source, args.target,
                                                                                               best_acc, d)))
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()

            print(log_str)
        if i % config["snapshot_interval"] == 0:
            torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \
                                                             "{}_to_{}_iter_{:05d}_model_{}.pth.tar".format(args.source,
                                                                                                            args.target,
                                                                                                            i, str(
                                                                     datetime.datetime.utcnow()))))
        print("it_train: {:05d} / {:05d} start".format(i, config["num_iterations"]))
        loss_params = config["loss"]
        ## train one iter
        classifier1.train(True)
        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()


        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        # inputs_source, inputs_target, labels_source = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()

        # extract features
        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)

        outputs_source1 = classifier1(features_source.detach())
        outputs_target1 = classifier1(features_target.detach())
        outputs1 = torch.cat((outputs_source1,outputs_target1),dim=0)
        softmax_out1 = nn.Softmax(dim=1)(outputs1)

        softmax_out = (1-args.cla_plus_weight)*softmax_out + args.cla_plus_weight*softmax_out1

        if config['method'] == 'CDAN+E':
            entropy = loss.Entropy(softmax_out)
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy, network.calc_coeff(i), random_layer)
        elif config['method'] == 'CDAN':
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, None, None, random_layer)
        elif config['method'] == 'DANN':
            transfer_loss = loss.DANN(features, ad_net)
        else:
            raise ValueError('Method cannot be recognized.')
        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)

        # Cycle
        num_feature = features_source.size(0)
        # ================= real/fake labels for the adversarial (MSE) losses
        real_label = Variable(torch.ones(num_feature))
        # real_label = Variable(torch.ones(num_feature)).cuda()
        fake_label = Variable(torch.zeros(num_feature))
        # fake_label = Variable(torch.zeros(num_feature)).cuda()

        # train the generators
        optimizer_G.zero_grad()

        # Identity loss
        same_t = G_s2t(features_target.detach())
        loss_identity_t = criterion_identity(same_t, features_target)

        same_s = G_t2s(features_source.detach())
        loss_identity_s = criterion_identity(same_s, features_source)

        # Gan loss
        fake_t = G_s2t(features_source.detach())
        pred_fake = D_t(fake_t)
        loss_G_s2t = criterion_GAN(pred_fake, real_label)

        fake_s = G_t2s(features_target.detach())
        pred_fake = D_s(fake_s)
        loss_G_t2s = criterion_GAN(pred_fake, real_label)

        # cycle loss
        recovered_s = G_t2s(fake_t)
        loss_cycle_sts = criterion_cycle(recovered_s, features_source)

        recovered_t = G_s2t(fake_s)
        loss_cycle_tst = criterion_cycle(recovered_t, features_target)

        # sem loss
        pred_recovered_s = base_network.fc(recovered_s)
        pred_fake_t = base_network.fc(fake_t)
        loss_sem_t2s = criterion_Sem(pred_recovered_s, pred_fake_t)

        pred_recovered_t = base_network.fc(recovered_t)
        pred_fake_s = base_network.fc(fake_s)
        loss_sem_s2t = criterion_Sem(pred_recovered_t, pred_fake_s)

        loss_cycle = loss_cycle_tst + loss_cycle_sts
        weights = args.weight_in_lossG.split(',')
        loss_G = float(weights[0]) * (loss_identity_s + loss_identity_t) + \
                 float(weights[1]) * (loss_G_s2t + loss_G_t2s) + \
                 float(weights[2]) * loss_cycle + \
                 float(weights[3]) * (loss_sem_s2t + loss_sem_t2s)



        # train the softmax classifier
        outputs_fake = classifier1(fake_t.detach())
        # classifier update
        classifier_loss1 = nn.CrossEntropyLoss()(outputs_fake, labels_source)
        classifier1_optim.zero_grad()
        classifier_loss1.backward()
        classifier1_optim.step()

        total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss + args.cyc_loss_weight*loss_G
        total_loss.backward()
        optimizer.step()
        optimizer_G.step()

        ###### Discriminator S ######
        optimizer_D_s.zero_grad()

        # Real loss
        pred_real = D_s(features_source.detach())
        loss_D_real = criterion_GAN(pred_real, real_label)

        # Fake loss
        fake_s = fake_S_buffer.push_and_pop(fake_s)
        pred_fake = D_s(fake_s.detach())
        loss_D_fake = criterion_GAN(pred_fake, fake_label)

        # Total loss
        loss_D_s = loss_D_real + loss_D_fake
        loss_D_s.backward()

        optimizer_D_s.step()
        ###################################

        ###### Discriminator t ######
        optimizer_D_t.zero_grad()

        # Real loss
        pred_real = D_t(features_target.detach())
        loss_D_real = criterion_GAN(pred_real, real_label)

        # Fake loss
        fake_t = fake_T_buffer.push_and_pop(fake_t)
        pred_fake = D_t(fake_t.detach())
        loss_D_fake = criterion_GAN(pred_fake, fake_label)

        # Total loss
        loss_D_t = loss_D_real + loss_D_fake
        loss_D_t.backward()
        optimizer_D_t.step()
        print("it_train: {:05d} / {:05d} over".format(i, config["num_iterations"]))
    now = datetime.datetime.now()
    d = str(now.month)+'-'+str(now.day)+' '+str(now.hour)+':'+str(now.minute)+":"+str(now.second)
    torch.save(best_model, osp.join(config["output_path"],
                                    "{}_to_{}_best_model_acc-{}_{}.pth.tar".format(args.source, args.target,
                                                                            best_acc,d)))
    return best_acc
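loss.CDAN is called above but not defined in this listing; the sketch below shows the conditional adversarial loss it is assumed to compute (multilinear conditioning from the CDAN paper), under the assumptions that ad_net ends in a sigmoid and that the batch is the source half followed by the target half; the entropy weighting of the CDAN+E variant and the random-layer branch are omitted:

import torch
import torch.nn as nn

def cdan_loss_sketch(features, softmax_output, ad_net):
    # Flattened outer product of features and class probabilities conditions the discriminator.
    op_out = torch.bmm(softmax_output.unsqueeze(2), features.unsqueeze(1))
    ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * features.size(1)))
    # First half of the batch is source (domain label 1), second half target (domain label 0).
    batch_size = softmax_output.size(0) // 2
    dc_target = torch.cat((torch.ones(batch_size), torch.zeros(batch_size))).to(ad_out.device)
    return nn.BCELoss()(ad_out.view(-1), dc_target)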
Example no. 5
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=True)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=True)

    if prep_config["test_10crop"]:
        dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                   transform=prep_dict["test"][i]) for i in range(10)]
        dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                                           shuffle=False, num_workers=0) for dset in dsets['test']]
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                  transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                          shuffle=False, num_workers=0)

    # class_num = config["network"]["params"]["class_num"]
    n_class = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network_stu = net_config["name"](**net_config["params"]).cuda()
    base_network_tea = net_config["name"](**net_config["params"]).cuda()

    ## add additional network for some methods
    if config["loss"]["random"]:
        random_layer = network.RandomLayer([base_network_stu.output_num(), n_class], config["loss"]["random_dim"])
        ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
    else:
        random_layer = None

        if config['method'] == 'DANN':
            ad_net = network.AdversarialNetwork(base_network_stu.output_num(), 1024)#DANN
        else:
            ad_net = network.AdversarialNetwork(base_network_stu.output_num() * n_class, 1024)

    if config["loss"]["random"]:
        random_layer.cuda()
    ad_net = ad_net.cuda()
    ad_net2 = network.AdversarialNetwork(n_class, n_class*4)
    ad_net2.cuda()

    parameter_list = base_network_stu.get_parameters() + ad_net.get_parameters()

    teacher_params = list(base_network_tea.parameters())
    for param in teacher_params:
        param.requires_grad = False

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                                         **(optimizer_config["optim_params"]))

    teacher_optimizer = EMAWeightOptimizer(base_network_tea, base_network_stu, alpha=0.99)

    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    best_acc = 0.0

    output1(log_name)

    loss1, loss2, loss3, loss4, loss5, loss6 = 0, 0, 0, 0, 0, 0

    output1('    =======    DA TRAINING    =======    ')

    best1 = 0
    f_t_result = []
    max_iter = config["num_iterations"]
    for i in range(max_iter+1):
        if i % config["test_interval"] == config["test_interval"] - 1 and i > 1500:
            base_network_tea.train(False)
            base_network_stu.train(False)

            # print("test")
            if 'MT' in config['method']:
                temp_acc = image_classification_test(dset_loaders,
                                                     base_network_tea, test_10crop=prep_config["test_10crop"])
                if temp_acc > best_acc:
                    best_acc = temp_acc

                log_str = "iter: {:05d}, tea_precision: {:.5f}".format(i, temp_acc)
                output1(log_str)
                #
                # if i > 20001 and best_acc < 0.69:
                #     break
                #
                # if temp_acc < 0.67:
                #     break
                #
                # if i > 30001 and best_acc < 0.71:
                #     break
                    # torch.save(base_network_tea, osp.join(path, "_model.pth.tar"))

                # temp_acc = image_classification_test(dset_loaders,
                #                                      base_network_stu, test_10crop=prep_config["test_10crop"])
                if temp_acc > best_acc:
                    best_acc = temp_acc
                    torch.save(base_network_stu, osp.join(path, "_model.pth.tar"))
                # log_str = "iter: {:05d}, stu_precision: {:.5f}".format(i, temp_acc)
                # output1(log_str)
            else:
                temp_acc = image_classification_test(dset_loaders,
                                                     base_network_stu, test_10crop=prep_config["test_10crop"])
                if temp_acc > best_acc:
                    best_acc = temp_acc
                    torch.save(base_network_stu, osp.join(path,"_model.pth.tar"))
                log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
                output1(log_str)

        loss_params = config["loss"]
        ## train one iter
        base_network_stu.train(True)
        base_network_tea.train(True)

        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()

        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])

        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)

        inputs_source, inputs_target, labels_source = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()
        features_source, outputs_source = base_network_stu(inputs_source)

        features_target_stu, outputs_target_stu = base_network_stu(inputs_target)
        features_target_tea, outputs_target_tea = base_network_tea(inputs_target)

        softmax_out_source = nn.Softmax(dim=1)(outputs_source)
        softmax_out_target_stu = nn.Softmax(dim=1)(outputs_target_stu)
        softmax_out_target_tea = nn.Softmax(dim=1)(outputs_target_tea)

        features = torch.cat((features_source, features_target_stu), dim=0)

        if 'MT' in config['method']:
            softmax_out = torch.cat((softmax_out_source, softmax_out_target_tea), dim=0)

        else:
            softmax_out = torch.cat((softmax_out_source, softmax_out_target_stu), dim=0)

        vat_loss = VAT(base_network_stu).cuda()

        n, d = features_source.shape
        decay = cal_decay(start=1,end=0.6,i = i)
        # image number in each class
        s_labels = labels_source
        t_max, t_labels = torch.max(softmax_out_target_tea, 1)
        t_max, t_labels = t_max.cuda(), t_labels.cuda()
        if config['method'] == 'DANN+dis' or config['method'] == 'CDRL':
            pass

        elif config['method'] == 'RESNET':
            classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
            total_loss = classifier_loss

        elif config['method'] == 'CDAN+E':
            entropy = losssntg.Entropy(softmax_out)
            ad_loss = losssntg.CDANori([features, softmax_out], ad_net, entropy, network.calc_coeff(i),
                                          random_layer)

            transfer_loss = loss_params["trade_off"] * ad_loss
            classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
            total_loss = transfer_loss + classifier_loss


        elif config['method'] == 'CDAN':
            ad_loss = losssntg.CDANori([features, softmax_out], ad_net, None, None, random_layer)

            transfer_loss = loss_params["trade_off"] * ad_loss
            classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
            total_loss = transfer_loss + classifier_loss


        elif config['method'] == 'DANN':

            ad_loss = losssntg.DANN(features, ad_net)
            transfer_loss = loss_params["trade_off"] * ad_loss
            classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
            total_loss = transfer_loss + classifier_loss


        elif config['method'] == 'CDAN+MT':
            th = config['th']

            ad_loss = losssntg.CDANori([features, softmax_out], ad_net, None, None, random_layer)
            unsup_loss = compute_aug_loss(softmax_out_target_stu, softmax_out_target_tea, n_class,confidence_thresh=th)
            unsup_loss = compute_aug_loss2(softmax_out_target_stu, softmax_out_target_tea, n_class,confidence_thresh=th)

            transfer_loss = loss_params["trade_off"] * ad_loss \
                            + 0.01 * unsup_loss
            classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
            total_loss = transfer_loss + classifier_loss

        elif config['method'] == 'CDAN+MT+VAT':
            cent = ConditionalEntropyLoss().cuda()
            ad_loss = losssntg.CDANori([features, softmax_out], ad_net, None, None, random_layer)
            unsup_loss = compute_aug_loss(softmax_out_target_stu, softmax_out_target_tea, n_class)
            loss_trg_cent = 1e-2 * cent(outputs_target_stu)
            loss_trg_vat = 1e-2 * vat_loss(inputs_target, outputs_target_stu)
            transfer_loss = loss_params["trade_off"] * ad_loss \
                            + 0.001*(unsup_loss + loss_trg_cent + loss_trg_vat)
            classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
            total_loss = transfer_loss + classifier_loss

        elif config['method'] == 'CDAN+MT+cent+VAT+temp':
            th = 0.7
            ad_loss = losssntg.CDANori([features, softmax_out], ad_net, None, None, random_layer)
            unsup_loss = compute_aug_loss(softmax_out_target_stu, softmax_out_target_tea, n_class,confidence_thresh=th)
            cent = ConditionalEntropyLoss().cuda()
            # loss_src_vat = vat_loss(inputs_source, outputs_source)
            loss_trg_cent = 1e-2 * cent(outputs_target_stu)
            loss_trg_vat = 1e-2 * vat_loss(inputs_target, outputs_target_stu)
            transfer_loss = loss_params["trade_off"] * ad_loss \
                            + unsup_loss + loss_trg_cent + loss_trg_vat
            # temperature
            classifier_loss = nn.NLLLoss()(nn.LogSoftmax(1)(outputs_source / 1.05), labels_source)
            total_loss = transfer_loss + classifier_loss

        elif config['method'] == 'CDAN+MT+cent+VAT+weightCross+T':

            if i % len_train_target == 0:

                if i != 0:
                    # print(cnt)
                    cnt = torch.tensor(cnt).float()
                    weight = cnt.sum() - cnt
                    weight = weight.cuda()
                else:
                    weight = torch.ones(n_class).cuda()

                cnt = [0] * n_class


            for j in t_labels:
                cnt[j.item()] += 1


            a = config['a']
            b = config['b']
            th = config['th']
            temp = config['temp']

            ad_loss = losssntg.CDANori([features, softmax_out], ad_net, None, None, random_layer)
            unsup_loss = compute_aug_loss2(softmax_out_target_stu, softmax_out_target_tea, n_class,confidence_thresh=th)
            # cbloss = compute_cbloss(softmax_out_target_stu, n_class, cls_balance=0.05)
            # unsup_loss = compute_aug_loss(softmax_out_target_stu, softmax_out_target_tea, n_class,confidence_thresh=th)

            cent = ConditionalEntropyLoss().cuda()

            # loss_src_vat = vat_loss(inputs_source, outputs_source)
            loss_trg_cent = 1e-2 * cent(outputs_target_stu)
            loss_trg_vat = 1e-2 * vat_loss(inputs_target, outputs_target_stu)
            classifier_loss = nn.CrossEntropyLoss(weight=weight)(outputs_source/temp, labels_source)
            # classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
            transfer_loss = loss_params["trade_off"] * ad_loss \
                + a*unsup_loss + b*(loss_trg_vat+loss_trg_cent)

            total_loss = transfer_loss + classifier_loss

        elif config['method'] == 'CDAN+MT+E+VAT+weightCross+T':
            entropy = losssntg.Entropy(softmax_out)
            if i % len_train_target == 0:

                if i != 0:
                    # print(cnt)
                    cnt = torch.tensor(cnt).float()
                    weight = cnt.sum() - cnt
                    weight = weight.cuda()
                else:
                    weight = torch.ones(n_class).cuda()

                cnt = [0] * n_class

            for j in t_labels:
                cnt[j.item()] += 1

            a = config['a']
            b = config['b']
            th = config['th']
            temp = config['temp']

            ad_loss = losssntg.CDANori([features, softmax_out], ad_net, entropy, network.calc_coeff(i),
                                       random_layer)
            # unsup_loss = compute_aug_loss2(softmax_out_target_stu, softmax_out_target_tea, n_class,confidence_thresh=th)
            # cbloss = compute_cbloss(softmax_out_target_stu, n_class, cls_balance=0.05)
            unsup_loss = compute_aug_loss(softmax_out_target_stu, softmax_out_target_tea, n_class, confidence_thresh=th)

            cent = ConditionalEntropyLoss().cuda()

            # loss_src_vat = vat_loss(inputs_source, outputs_source)
            loss_trg_cent = 1e-2 * cent(outputs_target_stu)
            loss_trg_vat = 1e-2 * vat_loss(inputs_target, outputs_target_stu)
            classifier_loss = nn.CrossEntropyLoss(weight=weight)(outputs_source / temp, labels_source)
            # classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
            transfer_loss = loss_params["trade_off"] * ad_loss \
                            + a * unsup_loss + b * (loss_trg_vat + loss_trg_cent)

            total_loss = transfer_loss + classifier_loss

        # adout
        # ad_out1 = ad_net(features_source)
        # w = 1-ad_out1
        # c = w * nn.CrossEntropyLoss(reduction='none')(outputs_source, labels_source)
        # classifier_loss = c.mean()
        # total_loss = transfer_loss + classifier_loss
        total_loss.backward()
        optimizer.step()
        teacher_optimizer.step()

        loss1 += ad_loss.item()
        loss2 += classifier_loss.item()
        # loss3 += unsup_loss.item()
        # loss4 += loss_trg_cent.item()
        # loss5 += loss_trg_vat.item()
        # loss6 += cbloss.item()
        # dis_sloss_l += sloss_l.item()

        if i % 50 == 0 and i != 0:
            output1('iter:{:d}, ad_loss_D:{:.2f}, closs:{:.2f}, unsup_loss:{:.2f}, loss_trg_cent:{:.2f}, loss_trg_vat:{:.2f}, cbloss:{:.2f}'
                    .format(i, loss1, loss2, loss3, loss4, loss5, loss6))
            loss1, loss2, loss3, loss4, loss5, loss6 = 0, 0, 0, 0, 0, 0

    # torch.save(best_model, osp.join(path, "best_model.pth.tar"))
    return best_acc
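EMAWeightOptimizer, which keeps the teacher network as a moving average of the student above, is not shown in this listing; a minimal sketch of the update it is assumed to perform:

class EMAWeightOptimizer:
    # Sketch: teacher weights track the student as an exponential moving average.
    def __init__(self, teacher_net, student_net, alpha=0.99):
        self.teacher_params = list(teacher_net.parameters())
        self.student_params = list(student_net.parameters())
        self.alpha = alpha
        # initialise the teacher from the student's current weights
        for t, s in zip(self.teacher_params, self.student_params):
            t.data.copy_(s.data)

    def step(self):
        one_minus_alpha = 1.0 - self.alpha
        for t, s in zip(self.teacher_params, self.student_params):
            t.data.mul_(self.alpha)
            t.data.add_(s.data * one_minus_alpha)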
Example no. 6
def main():
    # Training settings
    def str2bool(v):
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        elif v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        else:
            raise argparse.ArgumentTypeError('Unsupported value encountered.')

    parser = argparse.ArgumentParser(description='ALDA USPS2MNIST')
    parser.add_argument('method',
                        type=str,
                        default='ALDA',
                        choices=['DANN', "ALDA"])
    parser.add_argument('--task', default='MNIST2USPS', help='task to perform')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=2e-4,
                        metavar='LR',
                        help='learning rate (default: 2e-4)')
    parser.add_argument('--gpu_id', type=str, help='cuda device id')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=500,
        help='how many batches to wait before logging training status')
    parser.add_argument('--trade_off',
                        type=float,
                        default=1.0,
                        help="trade_off")
    parser.add_argument('--start_epoch',
                        type=int,
                        default=0,
                        help="begin adaptation after start_epoch")
    parser.add_argument('--threshold',
                        default=0.9,
                        type=float,
                        help="threshold of pseudo labels")
    parser.add_argument(
        '--output_dir',
        type=str,
        default=None,
        help="output directory of our model (in ../snapshot directory)")
    parser.add_argument('--loss_type',
                        type=str,
                        default='all',
                        help="whether add reg_loss or correct_loss.")
    parser.add_argument('--cos_dist',
                        type=str2bool,
                        default=False,
                        help="the classifier uses cosine similarity.")
    parser.add_argument('--num_worker', type=int, default=4)
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    network.set_device(args.gpu_id)

    if args.task == 'USPS2MNIST':
        source_list = './data/usps2mnist/usps_train.txt'
        target_list = './data/usps2mnist/mnist_train.txt'
        test_list = './data/usps2mnist/mnist_test.txt'
        start_epoch = 1
        decay_epoch = 6
    elif args.task == 'MNIST2USPS':
        source_list = './data/usps2mnist/mnist_train.txt'
        target_list = './data/usps2mnist/usps_train.txt'
        test_list = './data/usps2mnist/usps_test.txt'
        start_epoch = 1
        decay_epoch = 5
    else:
        raise Exception('task cannot be recognized!')

    source_list = open(source_list).readlines()
    target_list = open(target_list).readlines()
    test_list = open(test_list).readlines()

    train_loader = torch.utils.data.DataLoader(ImageList(
        source_list,
        transform=transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='L'),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_worker,
                                               drop_last=True,
                                               pin_memory=True)
    train_loader1 = torch.utils.data.DataLoader(ImageList(
        target_list,
        transform=transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='L'),
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=args.num_worker,
                                                drop_last=True,
                                                pin_memory=True)
    test_loader = torch.utils.data.DataLoader(ImageList(
        test_list,
        transform=transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='L'),
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              num_workers=args.num_worker,
                                              pin_memory=True)

    model = network.USPS_EnsembNet()
    model = model.to(network.dev)
    class_num = 10

    random_layer = None
    if args.method == "ALDA":
        ad_net = network.Multi_AdversarialNetwork(model.output_num(), 500,
                                                  class_num)
    elif args.method == "DANN":
        ad_net = network.AdversarialNetwork(model.output_num(), 500)
    ad_net = ad_net.to(network.dev)
    if args.task == 'USPS2MNIST':
        args.lr = 2e-4
    else:
        args.lr = 1e-3
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0005)
    optimizer_ad = optim.Adam(ad_net.parameters(),
                              lr=args.lr,
                              weight_decay=0.0005)

    start_epoch = args.start_epoch
    if args.output_dir is None:
        args.output_dir = args.task.lower() + '_' + args.method
    output_path = "snapshot/" + args.output_dir
    if os.path.exists(output_path):
        print("checkpoint dir exists, which will be removed")
        import shutil
        shutil.rmtree(output_path, ignore_errors=True)
    os.mkdir(output_path)

    for epoch in range(1, args.epochs + 1):
        if epoch % decay_epoch == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.5
        train(args, model, ad_net, train_loader, train_loader1, optimizer,
              optimizer_ad, epoch, start_epoch, args.method)
        test(args, model, test_loader)
        if epoch % 5 == 1:
            torch.save(model.state_dict(),
                       osp.join(output_path, "epoch_{}.pth".format(epoch)))
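The AdversarialNetwork used throughout these examples is a domain discriminator that is typically trained through a gradient reversal layer; a minimal sketch of such a layer (identity in the forward pass, sign-flipped and scaled gradient in the backward pass), shown only for illustration and not as the repository's exact implementation:

import torch

class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, coeff=1.0):
        ctx.coeff = coeff
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # flip and scale the gradient; None corresponds to the coeff argument
        return -ctx.coeff * grad_output, None

def grad_reverse(x, coeff=1.0):
    return GradReverse.apply(x, coeff)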
Example no. 7
def train(config):
    ####################################################
    # Tensorboard setting
    ####################################################
    #tensor_writer = SummaryWriter(config["tensorboard_path"])

    ####################################################
    # Data setting
    ####################################################

    prep_dict = {}  # data preprocessing (transforms)
    prep_dict["source"] = prep.image_train(**config['prep']['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    prep_dict["test"] = prep.image_test(**config['prep']['params'])

    dsets = {}
    dsets["source"] = datasets.ImageFolder(config['s_dset_path'],
                                           transform=prep_dict["source"])
    dsets["target"] = datasets.ImageFolder(config['t_dset_path'],
                                           transform=prep_dict['target'])
    dsets['test'] = datasets.ImageFolder(config['t_dset_path'],
                                         transform=prep_dict['test'])

    data_config = config["data"]
    train_source_bs = data_config["source"]["batch_size"]  # the original set both source and target to the source train batch size; changed here
    train_target_bs = data_config['target']['batch_size']
    test_bs = data_config["test"]["batch_size"]

    dset_loaders = {}
    dset_loaders["source"] = DataLoader(
        dsets["source"],
        batch_size=train_source_bs,
        shuffle=True,
        num_workers=4,
        drop_last=True
    )  # the original uses drop_last=True; this keeps source and target yielding equally sized batches to the end
    dset_loaders["target"] = DataLoader(dsets["target"],
                                        batch_size=train_target_bs,
                                        shuffle=True,
                                        num_workers=4,
                                        drop_last=True)
    dset_loaders['test'] = DataLoader(dsets['test'],
                                      batch_size=test_bs,
                                      shuffle=False,
                                      num_workers=4,
                                      drop_last=False)

    ####################################################
    # Network Setting
    ####################################################

    class_num = config["network"]['params']['class_num']

    net_config = config["network"]
    """
        config['network'] = {'name': network.ResNetFc,
                         'params': {'resnet_name': args.net,
                                    'use_bottleneck': True,
                                    'bottleneck_dim': 256,
                                    'new_cls': True,
                                    'class_num': args.class_num,
                                    'type' : args.type}
                         }
    """

    base_network = net_config["name"](**net_config["params"])
    # instantiate the ResNetFc() class defined in network.py
    base_network = base_network.cuda()  # ResNetFc(Resnet, True, 256, True, 12)

    if config["loss"]["random"]:
        random_layer = network.RandomLayer(
            [base_network.output_num(), class_num],
            config["loss"]["random_dim"])
        random_layer.cuda()
        ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(
            base_network.output_num() * class_num, 1024)  # why multiply by the number of classes?

    ad_net = ad_net.cuda()

    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    ####################################################
    # Env Setting
    ####################################################

    #gpus = config['gpu'].split(',')
    #if len(gpus) > 1 :
    #ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
    #base_network = nn.DataParallel(base_network, device_ids=[int(i) for i in gpus])

    ####################################################
    # Optimizer Setting
    ####################################################

    optimizer_config = config['optimizer']
    optimizer = optimizer_config["type"](parameter_list,
                                         **(optimizer_config["optim_params"]))
    # optim.SGD

    #config['optimizer'] = {'type': optim.SGD,
    #'optim_params': {'lr': args.lr,
    #'momentum': 0.9,
    #'weight_decay': 0.0005,
    #'nestrov': True},
    #'lr_type': "inv",
    #'lr_param': {"lr": args.lr,
    #'gamma': 0.001, # shouldn't this be 0.01?
    #'power': 0.75
    #}

    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group['lr'])

    schedule_param = optimizer_config['lr_param']

    lr_scheduler = lr_schedule.schedule_dict[
        optimizer_config["lr_type"]]  # return optimizer

    ####################################################
    # Train
    ####################################################

    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])

    transfer_loss_value = 0.0
    classifier_loss_value = 0.0
    total_loss_value = 0.0

    best_acc = 0.0

    batch_size = config["data"]["source"]["batch_size"]

    for i in range(config["num_iterations"]):  # num_iterations batches are used for training
        sys.stdout.write("Iteration : {} \r".format(i))
        sys.stdout.flush()

        loss_params = config["loss"]

        base_network.train(True)
        ad_net.train(True)

        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()

        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])

        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)

        inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
        inputs_target = inputs_target.cuda()

        inputs = torch.cat((inputs_source, inputs_target), dim=0)

        features, outputs, tau, cur_mean_source, cur_mean_target, output_mean_source, output_mean_target = base_network(
            inputs)

        softmax_out = nn.Softmax(dim=1)(outputs)

        outputs_source = outputs[:batch_size]
        outputs_target = outputs[batch_size:]

        if config['method'] == 'CDAN+E' or config['method'] == 'CDAN_TransNorm':
            entropy = loss.Entropy(softmax_out)
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy,
                                      network.calc_coeff(i), random_layer)
        elif config['method'] == 'CDAN':
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, None,
                                      None, random_layer)
        elif config['method'] == 'DANN':
            pass  # to be cleaned up later
        else:
            raise ValueError('Method cannot be recognized')

        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss

        total_loss.backward()
        optimizer.step()

        #tensor_writer.add_scalar('total_loss', total_loss.i )
        #tensor_writer.add_scalar('classifier_loss', classifier_loss, i)
        #tensor_writer.add_scalar('transfer_loss', transfer_loss, i)

        ####################################################
        # Test
        ####################################################
        if i % config["test_interval"] == config["test_interval"] - 1:
            # test interval 마다
            base_network.train(False)
            temp_acc = image_classification_test(dset_loaders, base_network)
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model
                ACC = round(best_acc * 100, 2)
                torch.save(
                    best_model,
                    os.path.join(config["output_path"],
                                 "iter_{}_model.pth.tar".format(ACC)))
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()
            print(log_str)
Exemplo n.º 8
0
def train_init(args):
    # prepare data
    dsets = {}
    dset_loaders = {}
    dsets["source"] = ImageList(open(args.source_list).readlines(), \
                                transform=image_train())
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=args.batch_size, \
                                        shuffle=True, num_workers=4, drop_last=True)
    dsets["target"] = ImageList(open(args.target_list).readlines(), \
                                transform=image_train())
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=args.batch_size, \
                                        shuffle=True, num_workers=4, drop_last=True)

    dsets["test"] = ImageList(open(args.target_list).readlines(), \
                              transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=2 * args.batch_size, \
                                      shuffle=False, num_workers=4)

    #model
    model = network.ResNet(class_num=args.num_class,
                           radius=args.radius,
                           trainable_radius=args.trainable_radius).cuda()
    adv_net = network.AdversarialNetwork(in_feature=model.output_num(),
                                         hidden_size=1024).cuda()
    parameter_list = model.get_parameters() + adv_net.get_parameters()
    optimizer = torch.optim.SGD(parameter_list,
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=0.005)

    gpus = args.gpu_id.split(',')
    if len(gpus) > 1:
        adv_net = nn.DataParallel(adv_net, device_ids=[int(i) for i in gpus])
        model = nn.DataParallel(model, device_ids=[int(i) for i in gpus])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    best_acc = 0.0
    best_model = copy.deepcopy(model)

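    # running class-centroid memories (one 256-d centroid per class) for the semantic alignment loss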
    Cs_memory = torch.zeros(args.num_class, 256).cuda()
    Ct_memory = torch.zeros(args.num_class, 256).cuda()

    for i in range(args.max_iter):
        if i % args.test_interval == args.test_interval - 1:
            model.train(False)
            temp_acc = image_classification_test(dset_loaders, model)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = copy.deepcopy(model)
            log_str = "\niter: {:05d}, \t precision: {:.4f},\t best_acc:{:.4f}".format(
                i, temp_acc, best_acc)
            args.log_file.write(log_str)
            args.log_file.flush()
            print(log_str)
        if i % args.snapshot_interval == args.snapshot_interval - 1:
            if not os.path.exists('snapshot'):
                os.mkdir('snapshot')
            if not os.path.exists('snapshot/save'):
                os.mkdir('snapshot/save')
            torch.save(best_model, 'snapshot/save/initial_model.pk')

        model.train(True)
        adv_net.train(True)
        if (args.lr_decay):
            optimizer = lr_schedule.inv_lr_scheduler(optimizer, i)

        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        inputs_source, inputs_target, labels_source = inputs_source.cuda(
        ), inputs_target.cuda(), labels_source.cuda()
        features_source, outputs_source = model(inputs_source)
        features_target, outputs_target = model(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)

        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        adv_loss = utils.loss_adv(features, adv_net)

        if args.baseline == 'MSTN':
            lam = network.calc_coeff(i)
        elif args.baseline == 'DANN':
            lam = 0.0
        else:
            raise ValueError('baseline cannot be recognized: {}'.format(args.baseline))
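        # pseudo-label the target batch with the current classifier for the SM (semantic alignment) loss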
        pseu_labels_target = torch.argmax(outputs_target, dim=1)
        loss_sm, Cs_memory, Ct_memory = utils.SM(features_source,
                                                 features_target,
                                                 labels_source,
                                                 pseu_labels_target, Cs_memory,
                                                 Ct_memory)
        total_loss = classifier_loss + adv_loss + lam * loss_sm
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        print('step: {:d},\tclass_loss: {:.4f},\tadv_loss: {:.4f}'.format(
            i, classifier_loss.item(), adv_loss.item()))
        Cs_memory.detach_()
        Ct_memory.detach_()

    return best_acc, best_model
Exemplo n.º 9
0
def train(config):
    ## set up summary writer
    writer = SummaryWriter(config['output_path'])
    class_num = config["network"]["params"]["class_num"]
    loss_params = config["loss"]

    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    center_criterion = loss_params["loss_type"](
        num_classes=class_num,
        feat_dim=config["network"]["params"]["bottleneck_dim"])

    ## prepare data
    dsets = {}
    dset_loaders = {}

    #sample without replacement: 70% train, 10% validation (middle slice), 20% test
    pristine_indices = torch.randperm(len(pristine_x))
    #train
    pristine_x_train = pristine_x[
        pristine_indices[:int(np.floor(.7 * len(pristine_x)))]]
    pristine_y_train = pristine_y[
        pristine_indices[:int(np.floor(.7 * len(pristine_x)))]]
    #validate --- gets passed into test functions in train file
    pristine_x_valid = pristine_x[pristine_indices[
        int(np.floor(.7 * len(pristine_x))):int(np.floor(.8 *
                                                         len(pristine_x)))]]
    pristine_y_valid = pristine_y[pristine_indices[
        int(np.floor(.7 * len(pristine_x))):int(np.floor(.8 *
                                                         len(pristine_x)))]]
    #test for evaluation file
    pristine_x_test = pristine_x[
        pristine_indices[int(np.floor(.8 * len(pristine_x))):]]
    pristine_y_test = pristine_y[
        pristine_indices[int(np.floor(.8 * len(pristine_x))):]]

    noisy_indices = torch.randperm(len(noisy_x))
    #train
    noisy_x_train = noisy_x[noisy_indices[:int(np.floor(.7 * len(noisy_x)))]]
    noisy_y_train = noisy_y[noisy_indices[:int(np.floor(.7 * len(noisy_x)))]]
    #validate --- gets passed into test functions in train file
    noisy_x_valid = noisy_x[noisy_indices[int(np.floor(.7 * len(noisy_x))
                                              ):int(np.floor(.8 *
                                                             len(noisy_x)))]]
    noisy_y_valid = noisy_y[noisy_indices[int(np.floor(.7 * len(noisy_x))
                                              ):int(np.floor(.8 *
                                                             len(noisy_x)))]]
    #test for evaluation file
    noisy_x_test = noisy_x[noisy_indices[int(np.floor(.8 * len(noisy_x))):]]
    noisy_y_test = noisy_y[noisy_indices[int(np.floor(.8 * len(noisy_x))):]]

    dsets["source"] = TensorDataset(pristine_x_train, pristine_y_train)
    dsets["target"] = TensorDataset(noisy_x_train, noisy_y_train)

    dsets["source_valid"] = TensorDataset(pristine_x_valid, pristine_y_valid)
    dsets["target_valid"] = TensorDataset(noisy_x_valid, noisy_y_valid)

    dsets["source_test"] = TensorDataset(pristine_x_test, pristine_y_test)
    dsets["target_test"] = TensorDataset(noisy_x_test, noisy_y_test)

    #data loaders; training batch size of 128 matches the batch sizes used further below
    dset_loaders["source"] = DataLoader(dsets["source"],
                                        batch_size=128,
                                        shuffle=True,
                                        num_workers=1)
    dset_loaders["target"] = DataLoader(dsets["target"],
                                        batch_size=128,
                                        shuffle=True,
                                        num_workers=1)

    #validation/test batch size of 64 follows the test setup in the original file
    dset_loaders["source_valid"] = DataLoader(dsets["source_valid"],
                                              batch_size=64,
                                              shuffle=True,
                                              num_workers=1)
    dset_loaders["target_valid"] = DataLoader(dsets["target_valid"],
                                              batch_size=64,
                                              shuffle=True,
                                              num_workers=1)

    dset_loaders["source_test"] = DataLoader(dsets["source_test"],
                                             batch_size=64,
                                             shuffle=True,
                                             num_workers=1)
    dset_loaders["target_test"] = DataLoader(dsets["target_test"],
                                             batch_size=64,
                                             shuffle=True,
                                             num_workers=1)

    config['out_file'].write("dataset sizes: source={}, target={}\n".format(
        len(dsets["source"]), len(dsets["target"])))

    config["num_iterations"] = len(
        dset_loaders["source"]) * config["epochs"] + 1
    config["test_interval"] = len(dset_loaders["source"])
    config["snapshot_interval"] = len(
        dset_loaders["source"]) * config["epochs"] * .25
    config["log_iter"] = len(dset_loaders["source"])

    #print the configuration you are using
    config["out_file"].write("config: {}\n".format(config))
    config["out_file"].flush()

    # set up early stop
    early_stop_engine = EarlyStopping(config["early_stop_patience"])

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## add additional network for some methods
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(
        high_value=config["high"])  #,
    #max_iter_value=config["num_iterations"])
    if use_gpu:
        ad_net = ad_net.cuda()

    ## collect parameters
    if "DeepMerge" in args.net:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
        parameter_list.append({
            "params": ad_net.parameters(),
            "lr_mult": .1,
            'decay_mult': 2
        })
        parameter_list.append({
            "params": center_criterion.parameters(),
            "lr_mult": 10,
            'decay_mult': 1
        })
    elif "ResNet18" in args.net:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
        parameter_list.append({
            "params": ad_net.parameters(),
            "lr_mult": .1,
            'decay_mult': 2
        })
        parameter_list.append({
            "params": center_criterion.parameters(),
            "lr_mult": 10,
            'decay_mult': 1
        })

    if net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.bottleneck.parameters(), "lr_mult":10, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
            parameter_list.append({
                "params": ad_net.parameters(),
                "lr_mult": config["ad_net_mult_lr"],
                'decay_mult': 2
            })
            parameter_list.append({
                "params": center_criterion.parameters(),
                "lr_mult": 10,
                'decay_mult': 1
            })
        else:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
            parameter_list.append({
                "params": ad_net.parameters(),
                "lr_mult": config["ad_net_mult_lr"],
                'decay_mult': 2
            })
            parameter_list.append({
                "params": center_criterion.parameters(),
                "lr_mult": 10,
                'decay_mult': 1
            })
    else:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
        parameter_list.append({
            "params": ad_net.parameters(),
            "lr_mult": config["ad_net_mult_lr"],
            'decay_mult': 2
        })
        parameter_list.append({
            "params": center_criterion.parameters(),
            "lr_mult": 10,
            'decay_mult': 1
        })
    #Should I put lr_mult here as 1 for DeepMerge too? Probably!

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    scan_lr = []
    scan_loss = []

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    len_valid_source = len(dset_loaders["source_valid"])
    len_valid_target = len(dset_loaders["target_valid"])

    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0

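    # each iteration: periodically validate, then train on one source batch + one target batch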
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == 0:
            base_network.train(False)
            if config['loss']['ly_type'] == "cosine":
                temp_acc, _ = image_classification_test(dset_loaders, 'source_valid', \
                    base_network, gpu=use_gpu, verbose = False, save_where = None)
                train_acc, _ = image_classification_test(dset_loaders, 'source', \
                    base_network, gpu=use_gpu, verbose = False, save_where = None)
            elif config['loss']['ly_type'] == "euclidean":
                temp_acc, _ = distance_classification_test(dset_loaders, 'source_valid', \
                    base_network, center_criterion.centers.detach(), gpu=use_gpu, verbose = False, save_where = None)
                train_acc, _ = distance_classification_test(
                    dset_loaders,
                    'source',
                    base_network,
                    center_criterion.centers.detach(),
                    gpu=use_gpu,
                    verbose=False,
                    save_where=None)
            else:
                raise ValueError("no test method for cls loss: {}".format(
                    config['loss']['ly_type']))

            snapshot_obj = {
                'epoch': i / len(dset_loaders["source"]),
                "base_network": base_network.state_dict(),
                'valid accuracy': temp_acc,
                'train accuracy': train_acc,
            }

            if (i + 1) % config["snapshot_interval"] == 0:
                torch.save(
                    snapshot_obj,
                    osp.join(
                        config["output_path"], "epoch_{}_model.pth.tar".format(
                            i / len(dset_loaders["source"]))))

            if config["loss"]["loss_name"] != "laplacian" and config["loss"][
                    "ly_type"] == "euclidean":
                snapshot_obj['center_criterion'] = center_criterion.state_dict(
                )

            if temp_acc > best_acc:
                best_acc = temp_acc

                # save best model
                torch.save(
                    snapshot_obj,
                    osp.join(config["output_path"], "best_model.pth.tar"))

            log_str = "epoch: {}, {} validation accuracy: {:.5f}, {} training accuracy: {:.5f}\n".format(
                i / len(dset_loaders["source"]), config['loss']['ly_type'],
                temp_acc, config['loss']['ly_type'], train_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            writer.add_scalar("validation accuracy", temp_acc,
                              i / len(dset_loaders["source"]))
            writer.add_scalar("training accuracy", train_acc,
                              i / len(dset_loaders["source"]))

        ## train one iter
        base_network.train(True)

        if i % config["log_iter"] == 0:
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        if config["optimizer"]["lr_type"] == "one-cycle":
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        if config["optimizer"]["lr_type"] == "linear":
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        optimizer_state = optimizer.state_dict()
        scan_lr.append(optimizer_state['param_groups'][0]['lr'])

        optimizer.zero_grad()

        try:
            inputs_source, labels_source = next(iter(dset_loaders["source"]))
            inputs_target, labels_target = next(iter(dset_loaders["target"]))
        except StopIteration:
            iter_source = iter(dset_loaders["source"])
            iter_target = iter(dset_loaders["target"])
            inputs_source, labels_source = next(iter_source)
            inputs_target, labels_target = next(iter_target)

        if use_gpu:
            inputs_source, inputs_target, labels_source = \
                Variable(inputs_source).cuda(), Variable(inputs_target).cuda(), \
                Variable(labels_source).cuda()
        else:
            inputs_source, inputs_target, labels_source = Variable(inputs_source), \
                Variable(inputs_target), Variable(labels_source)

        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        source_batch_size = inputs_source.size(0)

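        # 'cosine' scores with the network's classifier logits; 'euclidean' scores by negative
        # distance to the learned class centers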
        if config['loss']['ly_type'] == 'cosine':
            features, logits = base_network(inputs)
            source_logits = logits.narrow(0, 0, source_batch_size)
        elif config['loss']['ly_type'] == 'euclidean':
            features, _ = base_network(inputs)
            logits = -1.0 * loss.distance_to_centroids(
                features, center_criterion.centers.detach())
            source_logits = logits.narrow(0, 0, source_batch_size)

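        # adversarial (PADA) transfer loss on the concatenated features; domain accuracies are logged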
        ad_net.train(True)
        weight_ad = torch.ones(inputs.size(0))
        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                            weight_ad, use_gpu)
        ad_out, _ = ad_net(features.detach())
        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(ad_out)

        # source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source.long())
        # fisher loss on labeled source domain (only applied in the else branch below)

        if config["fisher_or_no"] == 'no':
            total_loss = loss_params["trade_off"] * transfer_loss \
            + classifier_loss

            scan_loss.append(total_loss.cpu().float().item())

            total_loss.backward()

            ######################################
            # Plot embeddings periodically.
            if args.blobs is not None and i / len(
                    dset_loaders["source"]) % 50 == 0:
                visualizePerformance(base_network,
                                     dset_loaders["source"],
                                     dset_loaders["target"],
                                     batch_size=128,
                                     domain_classifier=ad_net,
                                     num_of_samples=100,
                                     imgName='embedding_' +
                                     str(i / len(dset_loaders["source"])),
                                     save_dir=osp.join(config["output_path"],
                                                       "blobs"))
            ##########################################

            # if center_grad is not None:
            #     # clear mmc_loss
            #     center_criterion.centers.grad.zero_()
            #     # Manually assign centers gradients other than using autograd
            #     center_criterion.centers.backward(center_grad)

            optimizer.step()

            if i % config["log_iter"] == 0:

                if config['lr_scan'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"],
                                     "learning_rate_scan")):
                        os.makedirs(
                            osp.join(config["output_path"],
                                     "learning_rate_scan"))

                    plot_learning_rate_scan(
                        scan_lr, scan_loss, i / len(dset_loaders["source"]),
                        osp.join(config["output_path"], "learning_rate_scan"))

                if config['grad_vis'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"], "gradients")):
                        os.makedirs(
                            osp.join(config["output_path"], "gradients"))

                    plot_grad_flow(
                        osp.join(config["output_path"], "gradients"),
                        i / len(dset_loaders["source"]),
                        base_network.named_parameters())

                config['out_file'].write(
                    'epoch {}: train total loss={:0.4f}, train transfer loss={:0.4f}, train classifier loss={:0.4f},'
                    'train source+target domain accuracy={:0.4f}, train source domain accuracy={:0.4f}, train target domain accuracy={:0.4f}\n'
                    .format(
                        i / len(dset_loaders["source"]),
                        total_loss.data.cpu().float().item(),
                        transfer_loss.data.cpu().float().item(),
                        classifier_loss.data.cpu().float().item(),
                        ad_acc,
                        source_acc_ad,
                        target_acc_ad,
                    ))
                config['out_file'].flush()
                writer.add_scalar("training total loss",
                                  total_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training classifier loss",
                                  classifier_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training transfer loss",
                                  transfer_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training source+target domain accuracy",
                                  ad_acc, i / len(dset_loaders["source"]))
                writer.add_scalar("training source domain accuracy",
                                  source_acc_ad,
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training target domain accuracy",
                                  target_acc_ad,
                                  i / len(dset_loaders["source"]))

                #attempted validation step
                for j in range(0, len(dset_loaders["source_valid"])):
                    base_network.train(False)
                    with torch.no_grad():

                        try:
                            inputs_valid_source, labels_valid_source = next(
                                iter(dset_loaders["source_valid"]))
                            inputs_valid_target, labels_valid_target = next(
                                iter(dset_loaders["target_valid"]))
                        except StopIteration:
                            iter_valid_source = iter(dset_loaders["source_valid"])
                            iter_valid_target = iter(dset_loaders["target_valid"])
                            inputs_valid_source, labels_valid_source = next(iter_valid_source)
                            inputs_valid_target, labels_valid_target = next(iter_valid_target)

                        if use_gpu:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = \
                                Variable(inputs_valid_source).cuda(), Variable(inputs_valid_target).cuda(), \
                                Variable(labels_valid_source).cuda()
                        else:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = Variable(inputs_valid_source), \
                                Variable(inputs_valid_target), Variable(labels_valid_source)

                        valid_inputs = torch.cat(
                            (inputs_valid_source, inputs_valid_target), dim=0)
                        valid_source_batch_size = inputs_valid_source.size(0)

                        if config['loss']['ly_type'] == 'cosine':
                            features, logits = base_network(valid_inputs)
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)
                        elif config['loss']['ly_type'] == 'euclidean':
                            features, _ = base_network(valid_inputs)
                            logits = -1.0 * loss.distance_to_centroids(
                                features, center_criterion.centers.detach())
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)

                        ad_net.train(False)
                        weight_ad = torch.ones(valid_inputs.size(0))
                        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                                           weight_ad, use_gpu)
                        ad_out, _ = ad_net(features.detach())
                        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(
                            ad_out)

                        # source domain classification task loss
                        classifier_loss = class_criterion(
                            source_logits, labels_valid_source.long())

                        #if config["fisher_or_no"] == 'no':
                        total_loss = loss_params["trade_off"] * transfer_loss \
                                    + classifier_loss

                    if j % len(dset_loaders["source_valid"]) == 0:
                        config['out_file'].write(
                            'epoch {}: valid total loss={:0.4f}, valid transfer loss={:0.4f}, valid classifier loss={:0.4f},'
                            'valid source+target domain accuracy={:0.4f}, valid source domain accuracy={:0.4f}, valid target domain accuracy={:0.4f}\n'
                            .format(
                                i / len(dset_loaders["source"]),
                                total_loss.data.cpu().float().item(),
                                transfer_loss.data.cpu().float().item(),
                                classifier_loss.data.cpu().float().item(),
                                ad_acc,
                                source_acc_ad,
                                target_acc_ad,
                            ))
                        config['out_file'].flush()
                        writer.add_scalar("validation total loss",
                                          total_loss.data.cpu().float().item(),
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation classifier loss",
                            classifier_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation transfer loss",
                            transfer_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation source+target domain accuracy", ad_acc,
                            i / len(dset_loaders["source"]))
                        writer.add_scalar("validation source domain accuracy",
                                          source_acc_ad,
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar("validation target domain accuracy",
                                          target_acc_ad,
                                          i / len(dset_loaders["source"]))

                        if early_stop_engine.is_stop_training(
                                classifier_loss.cpu().float().item()):
                            config["out_file"].write(
                                "no improvement after {}, stop training at step {}\n"
                                .format(config["early_stop_patience"],
                                        i / len(dset_loaders["source"])))

                            sys.exit()

        else:
            fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(
                features.narrow(0, 0, int(inputs.size(0) / 2)),
                labels_source,
                inter_class=config["loss"]["inter_type"],
                intra_loss_weight=loss_params["intra_loss_coef"],
                inter_loss_weight=loss_params["inter_loss_coef"])
            # entropy minimization loss
            em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

            # final loss
            total_loss = loss_params["trade_off"] * transfer_loss \
                         + fisher_loss \
                         + loss_params["em_loss_coef"] * em_loss \
                         + classifier_loss

            scan_loss.append(total_loss.cpu().float().item())

            total_loss.backward()

            ######################################
            # Plot embeddings periodically.
            if args.blobs is not None and i / len(
                    dset_loaders["source"]) % 50 == 0:
                visualizePerformance(base_network,
                                     dset_loaders["source"],
                                     dset_loaders["target"],
                                     batch_size=128,
                                     num_of_samples=50,
                                     imgName='embedding_' +
                                     str(i / len(dset_loaders["source"])),
                                     save_dir=osp.join(config["output_path"],
                                                       "blobs"))
            ##########################################

            if center_grad is not None:
                # clear mmc_loss
                center_criterion.centers.grad.zero_()
                # Manually assign centers gradients other than using autograd
                center_criterion.centers.backward(center_grad)

            optimizer.step()

            if i % config["log_iter"] == 0:

                if config['lr_scan'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"],
                                     "learning_rate_scan")):
                        os.makedirs(
                            osp.join(config["output_path"],
                                     "learning_rate_scan"))

                    plot_learning_rate_scan(
                        scan_lr, scan_loss, i / len(dset_loaders["source"]),
                        osp.join(config["output_path"], "learning_rate_scan"))

                if config['grad_vis'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"], "gradients")):
                        os.makedirs(
                            osp.join(config["output_path"], "gradients"))

                    plot_grad_flow(
                        osp.join(config["output_path"], "gradients"),
                        i / len(dset_loaders["source"]),
                        base_network.named_parameters())

                config['out_file'].write(
                    'epoch {}: train total loss={:0.4f}, train transfer loss={:0.4f}, train classifier loss={:0.4f}, '
                    'train entropy min loss={:0.4f}, '
                    'train fisher loss={:0.4f}, train intra-group fisher loss={:0.4f}, train inter-group fisher loss={:0.4f}, '
                    'train source+target domain accuracy={:0.4f}, train source domain accuracy={:0.4f}, train target domain accuracy={:0.4f}\n'
                    .format(
                        i / len(dset_loaders["source"]),
                        total_loss.data.cpu().float().item(),
                        transfer_loss.data.cpu().float().item(),
                        classifier_loss.data.cpu().float().item(),
                        em_loss.data.cpu().float().item(),
                        fisher_loss.cpu().float().item(),
                        fisher_intra_loss.cpu().float().item(),
                        fisher_inter_loss.cpu().float().item(),
                        ad_acc,
                        source_acc_ad,
                        target_acc_ad,
                    ))

                config['out_file'].flush()
                writer.add_scalar("training total loss",
                                  total_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training classifier loss",
                                  classifier_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training transfer loss",
                                  transfer_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training entropy minimization loss",
                                  em_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training total fisher loss",
                                  fisher_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training intra-group fisher",
                                  fisher_intra_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training inter-group fisher",
                                  fisher_inter_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training source+target domain accuracy",
                                  ad_acc, i / len(dset_loaders["source"]))
                writer.add_scalar("training source domain accuracy",
                                  source_acc_ad,
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training target domain accuracy",
                                  target_acc_ad,
                                  i / len(dset_loaders["source"]))

                #attempted validation step
                for j in range(0, len(dset_loaders["source_valid"])):
                    base_network.train(False)
                    with torch.no_grad():

                        try:
                            inputs_valid_source, labels_valid_source = next(
                                iter(dset_loaders["source_valid"]))
                            inputs_valid_target, labels_valid_target = next(
                                iter(dset_loaders["target_valid"]))
                        except StopIteration:
                            iter_valid_source = iter(dset_loaders["source_valid"])
                            iter_valid_target = iter(dset_loaders["target_valid"])
                            inputs_valid_source, labels_valid_source = next(iter_valid_source)
                            inputs_valid_target, labels_valid_target = next(iter_valid_target)

                        if use_gpu:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = \
                                Variable(inputs_valid_source).cuda(), Variable(inputs_valid_target).cuda(), \
                                Variable(labels_valid_source).cuda()
                        else:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = Variable(inputs_valid_source), \
                                Variable(inputs_valid_target), Variable(labels_valid_source)

                        valid_inputs = torch.cat(
                            (inputs_valid_source, inputs_valid_target), dim=0)
                        valid_source_batch_size = inputs_valid_source.size(0)

                        if config['loss']['ly_type'] == 'cosine':
                            features, logits = base_network(valid_inputs)
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)
                        elif config['loss']['ly_type'] == 'euclidean':
                            features, _ = base_network(valid_inputs)
                            logits = -1.0 * loss.distance_to_centroids(
                                features, center_criterion.centers.detach())
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)

                        ad_net.train(False)
                        weight_ad = torch.ones(valid_inputs.size(0))
                        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                   weight_ad, use_gpu)
                        ad_out, _ = ad_net(features.detach())
                        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(
                            ad_out)

                        # source domain classification task loss
                        classifier_loss = class_criterion(
                            source_logits, labels_valid_source.long())

                        # fisher loss on labeled source domain
                        fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(
                            features.narrow(0, 0,
                                            int(valid_inputs.size(0) / 2)),
                            labels_valid_source,
                            inter_class=loss_params["inter_type"],
                            intra_loss_weight=loss_params["intra_loss_coef"],
                            inter_loss_weight=loss_params["inter_loss_coef"])
                        # entropy minimization loss
                        em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

                        # final loss
                        total_loss = loss_params["trade_off"] * transfer_loss \
                                     + fisher_loss \
                                     + loss_params["em_loss_coef"] * em_loss \
                                     + classifier_loss

                    if j % len(dset_loaders["source_valid"]) == 0:
                        config['out_file'].write(
                            'epoch {}: valid total loss={:0.4f}, valid transfer loss={:0.4f}, valid classifier loss={:0.4f}, '
                            'valid entropy min loss={:0.4f}, '
                            'valid fisher loss={:0.4f}, valid intra-group fisher loss={:0.4f}, valid inter-group fisher loss={:0.4f}, '
                            'valid source+target domain accuracy={:0.4f}, valid source domain accuracy={:0.4f}, valid target domain accuracy={:0.4f}\n'
                            .format(
                                i / len(dset_loaders["source"]),
                                total_loss.data.cpu().float().item(),
                                transfer_loss.data.cpu().float().item(),
                                classifier_loss.data.cpu().float().item(),
                                em_loss.data.cpu().float().item(),
                                fisher_loss.cpu().float().item(),
                                fisher_intra_loss.cpu().float().item(),
                                fisher_inter_loss.cpu().float().item(),
                                ad_acc,
                                source_acc_ad,
                                target_acc_ad,
                            ))

                        config['out_file'].flush()
                        writer.add_scalar("validation total loss",
                                          total_loss.data.cpu().float().item(),
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation classifier loss",
                            classifier_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation entropy minimization loss",
                            em_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation transfer loss",
                            transfer_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation total fisher loss",
                            fisher_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation intra-group fisher",
                            fisher_intra_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation inter-group fisher",
                            fisher_inter_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation source+target domain accuracy", ad_acc,
                            i / len(dset_loaders["source"]))
                        writer.add_scalar("validation source domain accuracy",
                                          source_acc_ad,
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar("validation target domain accuracy",
                                          target_acc_ad,
                                          i / len(dset_loaders["source"]))

                        if early_stop_engine.is_stop_training(
                                classifier_loss.cpu().float().item()):
                            config["out_file"].write(
                                "no improvement after {}, stop training at step {}\n"
                                .format(config["early_stop_patience"],
                                        i / len(dset_loaders["source"])))

                            sys.exit()

    return best_acc
Exemplo n.º 10
0
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='CDAN USPS MNIST')
    parser.add_argument('--method',
                        type=str,
                        default='CDAN-E',
                        choices=['CDAN', 'CDAN-E', 'DANN'])
    parser.add_argument('--task',
                        default='MNIST2USPS',
                        help='USPS2MNIST or MNIST2USPS')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--gpu_id',
                        default='0',
                        type=str,
                        help='cuda device id')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        help='how many batches to wait before logging training status')
    parser.add_argument('--random',
                        type=bool,
                        default=False,
                        help='whether to use random')
    parser.add_argument('--mdd_weight', type=float, default=0.05)
    parser.add_argument('--entropic_weight', type=float, default=0)
    parser.add_argument("--use_seed", type=bool, default=True)
    args = parser.parse_args()
    import random
    if (args.use_seed):
        torch.manual_seed(args.seed)

        np.random.seed(args.seed)
        random.seed(args.seed)
        torch.backends.cudnn.deterministic = True
    import os.path as osp
    import datetime
    config = {}
    config["output_path"] = "snapshot/" + args.task
    config['seed'] = args.seed
    config["torch_seed"] = torch.initial_seed()
    config["torch_cuda_seed"] = torch.cuda.initial_seed()

    config["mdd_weight"] = args.mdd_weight
    config["entropic_weight"] = args.entropic_weight
    if not osp.exists(config["output_path"]):
        os.system('mkdir -p ' + config["output_path"])
    config["out_file"] = open(
        osp.join(
            config["output_path"],
            "log_{}_{}.txt".format(args.task,
                                   str(datetime.datetime.utcnow()))), "w")

    torch.manual_seed(args.seed)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    if args.task == 'USPS2MNIST':
        source_list = 'data/usps2mnist/usps_train.txt'
        target_list = 'data/usps2mnist/mnist_train.txt'
        test_list = 'data/usps2mnist/mnist_test.txt'
        start_epoch = 1
        decay_epoch = 6
    elif args.task == 'MNIST2USPS':
        source_list = 'data/usps2mnist/mnist_train.txt'
        target_list = 'data/usps2mnist/usps_train.txt'
        test_list = 'data/usps2mnist/usps_test.txt'
        start_epoch = 1
        decay_epoch = 5
    else:
        raise Exception('task cannot be recognized!')

    train_loader = torch.utils.data.DataLoader(ImageList(
        open(source_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='L'),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1,
                                               drop_last=True)
    train_loader1 = torch.utils.data.DataLoader(ImageList(
        open(target_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='L'),
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=1,
                                                drop_last=True)
    test_loader = torch.utils.data.DataLoader(ImageList(
        open(test_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='L'),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              num_workers=1)

    model = network.LeNet()
    model = model.cuda()
    class_num = 10

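    # optional random multilinear map (RandomLayer) projects the feature/prediction pairing to
    # 500 dims before the domain discriminator; otherwise the full feature_dim * class_num
    # outer product is fed to ad_net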
    if args.random:
        random_layer = network.RandomLayer([model.output_num(), class_num],
                                           500)
        ad_net = network.AdversarialNetwork(500, 500)
        random_layer.cuda()
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(model.output_num() * class_num,
                                            500)
    ad_net = ad_net.cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          weight_decay=0.0005,
                          momentum=0.9)
    optimizer_ad = optim.SGD(ad_net.parameters(),
                             lr=args.lr,
                             weight_decay=0.0005,
                             momentum=0.9)
    config["out_file"].write(str(config) + "\n")
    config["out_file"].flush()
    for epoch in range(1, args.epochs + 1):
        if epoch % decay_epoch == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.5
        train(args, model, ad_net, random_layer, train_loader, train_loader1,
              optimizer, optimizer_ad, epoch, start_epoch, args.method)
        test(args, epoch, config, model, test_loader)
Exemplo n.º 11
0
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=True)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=True)

    if prep_config["test_10crop"]:
        for i in range(10):
            dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                       transform=prep_dict["test"][i]) for i in range(10)]
            dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                                               shuffle=False, num_workers=0) for dset in dsets['test']]
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                  transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                          shuffle=False, num_workers=0)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.cuda()

    ## add additional network for some methods
    if config["loss"]["random"]:
        random_layer = network.RandomLayer(
            [base_network.output_num(), class_num],
            config["loss"]["random_dim"])
        ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(
            base_network.output_num() * class_num, 1024)
    if config["loss"]["random"]:
        random_layer.cuda()
    ad_net = ad_net.cuda()
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                                         **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
        base_network = nn.DataParallel(base_network,
                                       device_ids=[int(i) for i in gpus])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    best_acc = 0.0
    best_model = nn.Sequential(base_network)
    each_log = ""
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == config["test_interval"] - 1:

            base_network.train(False)
            temp_acc = image_classification_test(dset_loaders, \
                                                 base_network, test_10crop=prep_config["test_10crop"])
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model
            log_str = "iter: {:05d}, precision: {:.5f}, transfer_loss:{:.4f}, classifier_loss:{:.4f}, total_loss:{:.4f}" \
                .format(i, temp_acc, transfer_loss.item(), classifier_loss.item(), total_loss.item())
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()
            print(log_str)

            config["out_file"].write(each_log)
            config["out_file"].flush()
            each_log = ""
        loss_params = config["loss"]
        ## train one iter
        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        inputs_source, inputs_target, labels_source = inputs_source.cuda(
        ), inputs_target.cuda(), labels_source.cuda()
        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)
        labels_target_fake = torch.max(nn.Softmax(dim=1)(outputs_target), 1)[1]
        labels = torch.cat((labels_source, labels_target_fake))
        entropy = loss.Entropy(softmax_out)
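        # CDAN conditions the domain discriminator on the pairing of features and softmax
        # predictions; the per-sample entropy computed above lets the entropy-conditioned
        # variant down-weight uncertain predictions when aligning the two domains.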
        transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy,
                                  network.calc_coeff(i), random_layer)

        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        mdd_loss = loss.mdd_loss(features=features,
                                 labels=labels,
                                 left_weight=args.left_weight,
                                 right_weight=args.right_weight)
        max_entropy_loss = loss.EntropicConfusion(features)
        total_loss = loss_params["trade_off"] * transfer_loss \
                     + args.cls_weight * classifier_loss \
                     + args.mdd_weight * mdd_loss \
                     + args.entropic_weight * max_entropy_loss
        total_loss.backward()
        optimizer.step()
        log_str = "iter: {:05d},transfer_loss:{:.4f}, classifier_loss:{:.4f}, mdd_loss:{:4f}," \
                  "max_entropy_loss:{:.4f},total_loss:{:.4f}" \
            .format(i, transfer_loss.item(), classifier_loss.item(), mdd_loss.item(),
                    max_entropy_loss.item(), total_loss.item())
        each_log += log_str + "\n"

    torch.save(
        best_model, config['model_output_path'] + "{}_{}_p-{}_e-{}".format(
            config['log_name'], str(best_acc), str(config["mdd_weight"]),
            str(config["entropic_weight"])))
    return best_acc
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='CDAN SVHN MNIST')
    parser.add_argument('--method',
                        type=str,
                        default='CDAN',
                        choices=['CDAN', 'CDAN-E', 'DANN'])
    parser.add_argument('--task', default='USPS2MNIST', help='task to perform')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.03, metavar='LR')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--gpu_id',
                        type=str,
                        default='0',
                        help='cuda device id')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        help='how many batches to wait before logging training status')
    parser.add_argument('--random',
                        type=bool,
                        default=False,
                        help='whether to use random')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    source_list = '/data/svhn/svhn_train.txt'
    target_list = '/data/mnist/mnist_train.txt'
    test_list = '/data/mnist/mnist_test.txt'

    train_loader = torch.utils.data.DataLoader(ImageList(
        open(source_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize(28),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='RGB'),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)
    train_loader1 = torch.utils.data.DataLoader(ImageList(
        open(target_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize(28),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='RGB'),
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=1)
    test_loader = torch.utils.data.DataLoader(ImageList(
        open(test_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize(28),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='RGB'),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              num_workers=1)

    # img_transform_for_svhn = transforms.Compose([
    #     transforms.Resize(28),
    #     transforms.ToTensor(),
    #     transforms.Normalize(mean=(0.1307,), std=(0.3081,))
    #     # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    #     # transforms.Normalize((0.5,), (0.5,))
    # ])
    # img_transform_for_mnist = transforms.Compose([
    #     transforms.Resize(28),
    #     transforms.ToTensor(),
    #     # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    #     transforms.Normalize(mean=(0.1307,), std=(0.3081,))
    #     # transforms.Normalize((0.5,), (0.5,))
    # ])

    # dataset_source = datasets.SVHN(
    #     root=os.path.abspath(os.path.join(os.path.dirname(__file__), '../data/svhn2mnist')),
    #     split='train',
    #     transform=img_transform_for_svhn,
    #     download=True
    # )
    # train_loader = torch.utils.data.DataLoader(
    #     dataset=dataset_source,
    #     batch_size=args.batch_size,
    #     shuffle=True,
    #     num_workers=0)

    # dataset_target = datasets.MNIST(
    #     root=os.path.abspath(os.path.join(os.path.dirname(__file__), '../data/svhn2mnist')),
    #     train=True,
    #     transform=img_transform_for_mnist,
    #     download=True
    # )
    # train_loader1 = torch.utils.data.DataLoader(
    #     dataset=dataset_target,
    #     batch_size=args.batch_size,
    #     shuffle=True,
    #     num_workers=0)

    # dataset_test = datasets.MNIST(
    #     root=os.path.abspath(os.path.join(os.path.dirname(__file__), '../data/svhn2mnist')),
    #     train=False,
    #     transform=img_transform_for_mnist,
    #     download=True
    # )
    # test_loader = torch.utils.data.DataLoader(
    #     dataset=dataset_test,
    #     batch_size=args.batch_size,
    #     shuffle=True,
    #     num_workers=0)

    model = network.DTN()
    # model = model.cuda()
    class_num = 10

    if args.random:
        random_layer = network.RandomLayer([model.output_num(), class_num],
                                           500)
        ad_net = network.AdversarialNetwork(500, 500)
        # random_layer.cuda()
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(model.output_num() * class_num,
                                            500)
    # ad_net = ad_net.cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          weight_decay=0.0005,
                          momentum=0.9)
    optimizer_ad = optim.SGD(ad_net.parameters(),
                             lr=args.lr,
                             weight_decay=0.0005,
                             momentum=0.9)

    for epoch in range(1, args.epochs + 1):
        if epoch % 3 == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.3
        train(args, model, ad_net, random_layer, train_loader, train_loader1,
              optimizer, optimizer_ad, epoch, 0, args.method)
        test(args, model, test_loader)
Exemplo n.º 13
0
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    tensor_writer = SummaryWriter(config["tensorboard_path"])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
            shuffle=True, num_workers=4, drop_last=True)

    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
            shuffle=True, num_workers=4, drop_last=True)

    if prep_config["test_10crop"]:
        dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
                            transform=prep_dict["test"][i]) for i in range(10)]
        dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                            shuffle=False, num_workers=4) for dset in dsets['test']]
        dsets["source_val"] = [ImageList(open(data_config["source"]["list_path"]).readlines(), \
                            transform=prep_dict["test"][i]) for i in range(10)]
        dset_loaders["source_val"] = [DataLoader(dset, batch_size=test_bs, \
                            shuffle=False, num_workers=4) for dset in dsets['source_val']]
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                shuffle=False, num_workers=4)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.cuda()

    ## add additional network for some methods
    if config["loss"]["random"]:
        random_layer = network.RandomLayer(
            [base_network.output_num(), class_num],
            config["loss"]["random_dim"])
        ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
    else:
        random_layer = None
        # ad_net = network.AdversarialNetwork(base_network.output_num() * class_num, 1024)
        ad_net = network.AdversarialNetwork(base_network.output_num(), 1024)
    if config["loss"]["random"]:
        random_layer.cuda()
    ad_net = ad_net.cuda()
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
        base_network = nn.DataParallel(base_network,
                                       device_ids=[int(i) for i in gpus])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == config["test_interval"] - 1:
            base_network.train(False)
            temp_acc, output, prediction, label, feature = image_classification_test(dset_loaders, \
                base_network, test_10crop=prep_config["test_10crop"])
            _, output_src, prediction_src, label_src, feature_src = image_classification_val(dset_loaders, \
                base_network, test_10crop=prep_config["test_10crop"])
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model
        if i % config["snapshot_interval"] == 0:
            torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \
                "iter_{:05d}_model.pth.tar".format(i)))

        loss_params = config["loss"]
        ## train one iter
        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        inputs_source, inputs_target, labels_source = inputs_source.cuda(
        ), inputs_target.cuda(), labels_source.cuda()
        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)

        target_softmax_out = nn.Softmax(dim=1)(outputs_target)
        target_entropy = EntropyLoss(target_softmax_out)
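        # Minimum Class Confusion (MCC) penalty, computed below: temperature-scale the
        # target logits, weight confident (low-entropy) target predictions more heavily,
        # build the class-correlation (confusion) matrix, normalize it, and penalize its
        # off-diagonal mass averaged over the number of classes.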

        temperature = 3.0
        outputs_target_temp = outputs_target / temperature
        target_softmax_out_temp = nn.Softmax(dim=1)(outputs_target_temp)
        target_entropy_weight = loss.Entropy(target_softmax_out_temp).detach()
        target_entropy_weight = 1 + torch.exp(-target_entropy_weight)
        target_entropy_weight = train_bs * target_entropy_weight / torch.sum(
            target_entropy_weight)

        cov_matrix_t_temp = target_softmax_out_temp.mul(
            target_entropy_weight.view(-1, 1)).transpose(
                1, 0).mm(target_softmax_out_temp)
        cov_matrix_t_temp = cov_matrix_t_temp / torch.sum(cov_matrix_t_temp,
                                                          dim=1)

        mcc_loss = (torch.sum(cov_matrix_t_temp) -
                    torch.trace(cov_matrix_t_temp)) / class_num

        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        total_loss = classifier_loss + mcc_loss
        total_loss.backward()
        optimizer.step()

        tensor_writer.add_scalar('total_loss', total_loss, i)
        tensor_writer.add_scalar('classifier_loss', classifier_loss, i)
        tensor_writer.add_scalar('cov_matrix_penalty', mcc_loss, i)

    torch.save(best_model, osp.join(config["output_path"],
                                    "best_model.pth.tar"))
    return best_acc
Exemplo n.º 14
0
def cdan_model_fn(features, labels, mode, params):
    model_class = params["model"]
    resnet_size = params["resnet_size"]
    num_classes = params["num_classes"]
    weight_decay = params["weight_decay"]
    loss_scale = params["loss_scale"]
    momentum = params["momentum"]
    base_lr = params["base_lr"]
    batch_size = params["batch_size"]

    model = model_class(resnet_size,
                        data_format="channels_last",
                        num_classes=num_classes)

    if mode == tf.estimator.ModeKeys.PREDICT:

        logits, hidden_features = model(features,
                                        mode == tf.estimator.ModeKeys.TRAIN)
        logits = tf.cast(logits, tf.float32)

        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        # Return the predictions and the specification for serving a SavedModel
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'predict': tf.estimator.export.PredictOutput(predictions)
            })
    elif mode == tf.estimator.ModeKeys.TRAIN:
        s_input = features["source"]
        t_input = features["target"]
        s_label = labels
        ad_s_label = features["ad_s_label"]
        ad_t_label = features["ad_t_label"]
        model = model_class(resnet_size,
                            data_format="channels_last",
                            num_classes=num_classes)
        logits, hidden_features = model(tf.concat((s_input, t_input), 0),
                                        mode == tf.estimator.ModeKeys.TRAIN)
        logits = tf.cast(logits, tf.float32)
        ad_labels = tf.cast(tf.concat((ad_s_label, ad_t_label), 0), tf.float32)
        mid_point = tf.shape(s_input)[0]

        predictions = {
            'classes':
            tf.argmax(tf.slice(logits, [0, 0], [mid_point, num_classes]),
                      axis=1),
            'probabilities':
            tf.nn.softmax(tf.slice(logits, [0, 0], [mid_point, num_classes]),
                          name='softmax_tensor')
        }

        global_step = tf.train.get_or_create_global_step()

        ad_net = network.AdversarialNetwork(global_step)
        ad_out = ad_net(hidden_features, mode == tf.estimator.ModeKeys.TRAIN)
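        # The adversarial head scores every hidden feature of the concatenated
        # source+target batch; it is trained below with sigmoid cross-entropy against
        # the binary source/target domain labels.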

        cross_entropy = tf.losses.sparse_softmax_cross_entropy(logits=tf.slice(
            logits, [0, 0], [mid_point, num_classes]),
                                                               labels=s_label)
        tf.identity(cross_entropy, name='cross_entropy')
        tf.summary.scalar('cross_entropy', cross_entropy)

        adversarial_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=ad_out,
                                                    labels=ad_labels))
        tf.identity(adversarial_loss, name='adversarial_loss')
        tf.summary.scalar('adversarial_loss', adversarial_loss)

        def exclude_batch_norm(name):
            return 'batch_normalization' not in name

        l2_loss = weight_decay * tf.add_n(
            # loss is computed using fp32 for numerical stability.
            [
                tf.nn.l2_loss(tf.cast(v, tf.float32))
                for v in tf.trainable_variables() if exclude_batch_norm(v.name)
            ])
        tf.summary.scalar('l2_loss', l2_loss)
        loss = cross_entropy + l2_loss + adversarial_loss

        learning_rate = inv_lr_decay(base_lr,
                                     global_step,
                                     gamma=0.001,
                                     power=0.75)
        tf.identity(learning_rate, name='learning_rate')
        tf.summary.scalar('learning_rate', learning_rate)

        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=momentum)

        def _grad_filter(gvs):
            return [(g, v) for g, v in gvs if not ('dense' in v.name)
                    ], [(g, v) for g, v in gvs if 'dense' in v.name]

        if loss_scale != 1:
            scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale)
            dense_grad_vars, other_grad_vars = _grad_filter(scaled_grad_vars)
            other_grad_vars = [(grad / loss_scale, var)
                               for grad, var in other_grad_vars]
            dense_grad_vars = [(grad / loss_scale * 10.0, var)
                               for grad, var in dense_grad_vars]
            minimize_op = optimizer.apply_gradients(
                dense_grad_vars + other_grad_vars, global_step)
        else:
            grad_vars = optimizer.compute_gradients(loss)
            dense_grad_vars, other_grad_vars = _grad_filter(grad_vars)
            other_grad_vars = [(grad, var) for grad, var in other_grad_vars]
            dense_grad_vars = [(grad * 10.0, var)
                               for grad, var in dense_grad_vars]
            minimize_op = optimizer.apply_gradients(
                dense_grad_vars + other_grad_vars, global_step)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        train_op = tf.group(minimize_op, update_ops)
        accuracy = tf.metrics.accuracy(s_label, predictions['classes'])

        metrics = {'accuracy': accuracy}

        tf.identity(accuracy[1], name='train_accuracy')
        tf.summary.scalar('train_accuracy', accuracy[1])

    else:
        logits, hidden_features = model(features,
                                        mode == tf.estimator.ModeKeys.TRAIN)
        logits = tf.cast(logits, tf.float32)

        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        train_op = None

        cross_entropy = tf.losses.sparse_softmax_cross_entropy(logits=logits,
                                                               labels=labels)
        tf.identity(cross_entropy, name='cross_entropy')
        tf.summary.scalar('cross_entropy', cross_entropy)

        def exclude_batch_norm(name):
            return 'batch_normalization' not in name

        l2_loss = weight_decay * tf.add_n(
            # loss is computed using fp32 for numerical stability.
            [
                tf.nn.l2_loss(tf.cast(v, tf.float32))
                for v in tf.trainable_variables() if exclude_batch_norm(v.name)
            ])
        tf.summary.scalar('l2_loss', l2_loss)
        loss = cross_entropy + l2_loss

        accuracy = tf.metrics.accuracy(labels, predictions['classes'])

        metrics = {'accuracy': accuracy}

        tf.identity(accuracy[1], name='train_accuracy')
        tf.summary.scalar('train_accuracy', accuracy[1])

    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=predictions,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=metrics)
def train(config):
    tie = 1.0
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train( \
        resize_size=prep_config["resize_size"], \
        crop_size=prep_config["crop_size"])
    prep_dict["target"] = prep.image_train( \
        resize_size=prep_config["resize_size"], \
        crop_size=prep_config["crop_size"])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop( \
            resize_size=prep_config["resize_size"], \
            crop_size=prep_config["crop_size"])
    else:
        prep_dict["test"] = prep.image_test( \
            resize_size=prep_config["resize_size"], \
            crop_size=prep_config["crop_size"])

    ## set loss
    class_criterion = nn.CrossEntropyLoss()  # classification loss
    transfer_criterion1 = loss.PADA  # weighted transfer (domain-adversarial) loss, BCE-based
    balance_criterion = loss.balance_loss()
    loss_params = config["loss"]

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = util_data.DataLoader(dsets["source"], \
                                                  batch_size=data_config["source"]["batch_size"], \
                                                  shuffle=True, num_workers=4)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = util_data.DataLoader(dsets["target"], \
                                                  batch_size=data_config["target"]["batch_size"], \
                                                  shuffle=True, num_workers=4)

    if prep_config["test_10crop"]:
        for i in range(10):
            dsets["test" + str(i)] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                               transform=prep_dict["test"]["val" + str(i)])
            dset_loaders["test" + str(i)] = util_data.DataLoader(dsets["test" + str(i)], \
                                                                 batch_size=data_config["test"]["batch_size"], \
                                                                 shuffle=False, num_workers=4)

            dsets["target" + str(i)] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                                 transform=prep_dict["test"]["val" + str(i)])
            dset_loaders["target" + str(i)] = util_data.DataLoader(dsets["target" + str(i)], \
                                                                   batch_size=data_config["test"]["batch_size"], \
                                                                   shuffle=False, num_workers=4)
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                  transform=prep_dict["test"])
        dset_loaders["test"] = util_data.DataLoader(dsets["test"], \
                                                    batch_size=data_config["test"]["batch_size"], \
                                                    shuffle=False, num_workers=4)

        dsets["target_test"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                         transform=prep_dict["test"])
        dset_loaders["target_test"] = MyDataLoader(dsets["target_test"], \
                                                   batch_size=data_config["test"]["batch_size"], \
                                                   shuffle=False, num_workers=4)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameters
    if net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [{"params": base_network.feature_layers.parameters(), "lr": 1}, \
                              {"params": base_network.bottleneck.parameters(), "lr": 10}, \
                              {"params": base_network.fc.parameters(), "lr": 10}]
        else:
            parameter_list = [{"params": base_network.feature_layers.parameters(), "lr": 1}, \
                              {"params": base_network.fc.parameters(), "lr": 10}]
    else:
        parameter_list = [{"params": base_network.parameters(), "lr": 1}]

    ## add additional network for some methods
    class_weight = torch.from_numpy(np.array(
        [1.0] * class_num))  # initialize the class weights as a class_num-dim vector of ones
    if use_gpu:
        class_weight = class_weight.cuda()
    ad_net = network.AdversarialNetwork(
        base_network.output_num())  # domain discriminator: input dim = feature extractor output, output = probability of belonging to the shared label space
    nad_net = network.NAdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(
        high_value=config["high"])
    if use_gpu:
        ad_net = ad_net.cuda()
        nad_net = nad_net.cuda()
    parameter_list.append({"params": ad_net.parameters(), "lr": 10})

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](parameter_list, \
                                                     **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    len_train_source = len(dset_loaders["source"]) - 1
    len_train_target = len(dset_loaders["target"]) - 1
    transfer_loss_value = classifier_loss_value = total_loss_value = attention_loss_value = 0.0
    best_acc = 0.0
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == 0:  # 测试
            base_network.train(False)  # 先对特征提取器进行训练
            temp_acc = image_classification_test(dset_loaders, \
                                                 base_network, test_10crop=prep_config["test_10crop"], \
                                                 gpu=use_gpu)  # 对第一个有监督训练器进行分类训练,提取特征
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model
            log_str = "iter: {:05d}, alpha: {:03f}, tradeoff: {:03f} ,precision: {:.5f}".format(
                i, loss_params["para_alpha"], loss_params["trade_off"],
                temp_acc)
            config["out_file"].write(log_str + '\n')
            config["out_file"].flush()
            print(log_str)
        if i % config["snapshot_interval"] == 0:  # 输出
            torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \
                                                             "iter_{:05d}_model.pth.tar".format(i)))

        if i % loss_params["update_iter"] == loss_params["update_iter"] - 1:
            base_network.train(False)
            target_fc8_out = image_classification_predict(
                dset_loaders,
                base_network,
                softmax_param=config["softmax_param"]
            )  # feed target-domain data through the trained classifier to estimate class weights
            class_weight = torch.mean(
                target_fc8_out, 0)  # average the predicted probability vectors to turn per-sample scores into per-class weights
            class_weight = (class_weight /
                            torch.mean(class_weight)).cuda().view(-1)  # normalize
            class_criterion = nn.CrossEntropyLoss(
                weight=class_weight)  # class-weighted cross-entropy loss

        ## train one iter
        base_network.train(True)
        optimizer = lr_scheduler(param_lr, optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        if use_gpu:
            inputs_source, inputs_target, labels_source = \
                Variable(inputs_source).cuda(), Variable(inputs_target).cuda(), \
                Variable(labels_source).cuda()
        else:
            inputs_source, inputs_target, labels_source = Variable(inputs_source), \
                                                          Variable(inputs_target), Variable(labels_source)

        inputs = torch.cat((inputs_source, inputs_target),
                           dim=0)  # first half of the batch is source, second half is target
        features, outputs = base_network(inputs)  # features and class outputs from the base network (feature extractor + classifier)
        softmax_out = nn.Softmax(dim=1)(outputs).detach()
        ad_net.train(True)
        nad_net.train(True)
        weight_ad = torch.zeros(inputs.size(0))
        label_numpy = labels_source.data.cpu().numpy()
        for j in range(inputs.size(0) // 2):
            weight_ad[j] = class_weight[int(label_numpy[j])]  # per-sample weight taken from its class weight
        # print(label_numpy)
        weight_ad = weight_ad / torch.max(
            weight_ad[0:inputs.size(0) // 2])  # normalize the source weights
        for j in range(inputs.size(0) // 2,
                       inputs.size(0)):  # source half uses the computed weights; target half gets weight 1
            weight_ad[j] = 1.0
        classifier_loss = class_criterion(
            outputs.narrow(0, 0,
                           inputs.size(0) // 2), labels_source)  # classification loss
        total_loss = classifier_loss
        total_loss.backward()
        optimizer.step()

    torch.save(best_model, osp.join(config["output_path"],
                                    "best_model.pth.tar"))
    return best_acc
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='CDAN USPS MNIST')
    parser.add_argument('method',
                        type=str,
                        default='CDAN-E',
                        choices=[
                            'CDAN', 'CDAN-E', 'DANN', 'IWDAN', 'NANN',
                            'IWDANORACLE', 'IWCDAN', 'IWCDANORACLE',
                            'IWCDAN-E', 'IWCDAN-EORACLE'
                        ])
    parser.add_argument('--task',
                        default='mnist2usps',
                        help='task to perform',
                        choices=['usps2mnist', 'mnist2usps'])
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=70,
                        metavar='N',
                        help='number of epochs to train (default: 70)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0,
                        metavar='LR',
                        help='learning rate (0 = use the per-task default of 0.02)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=50,
        help='how many batches to wait before logging training status')
    parser.add_argument(
        '--root_folder',
        type=str,
        default='data/usps2mnist/',
        help="The folder containing the datasets and the lists")
    parser.add_argument('--output_dir',
                        type=str,
                        default='results',
                        help="output directory")
    parser.add_argument(
        "-u",
        "--mu",
        help="Hyperparameter of the coefficient of the domain adversarial loss",
        type=float,
        default=1.0)
    parser.add_argument('--ratio', type=float, default=0, help='ratio option')
    parser.add_argument('--ma',
                        type=float,
                        default=0.5,
                        help='weight for the moving average of iw')
    args = parser.parse_args()
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Running the JSD experiment on fewer epochs for efficiency
    if args.ratio >= 100:
        args.epochs = 25

    print('Running {} on {} for {} epochs on task {}'.format(
        args.method, args.device, args.epochs, args.task))

    # Set random number seed.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    if args.task == 'usps2mnist':

        # CDAN parameters
        decay_epoch = 6
        decay_frac = 0.5
        lr = 0.02
        start_epoch = 1
        model = network.LeNet(args.ma)
        build_dataset = build_uspsmnist

        source_list = os.path.join(args.root_folder, 'usps_train.txt')
        source_path = os.path.join(args.root_folder, 'usps_train_dataset.pkl')
        target_list = os.path.join(args.root_folder, 'mnist_train.txt')
        target_path = os.path.join(args.root_folder, 'mnist_train_dataset.pkl')
        test_list = os.path.join(args.root_folder, 'mnist_test.txt')
        test_path = os.path.join(args.root_folder, 'mnist_test_dataset.pkl')

    elif args.task == 'mnist2usps':

        decay_epoch = 5
        decay_frac = 0.5
        lr = 0.02
        start_epoch = 1
        model = network.LeNet(args.ma)
        build_dataset = build_uspsmnist

        source_list = os.path.join(args.root_folder, 'mnist_train.txt')
        source_path = os.path.join(args.root_folder, 'mnist_train_dataset.pkl')
        target_list = os.path.join(args.root_folder, 'usps_train.txt')
        target_path = os.path.join(args.root_folder, 'usps_train_dataset.pkl')
        test_list = os.path.join(args.root_folder, 'usps_test.txt')
        test_path = os.path.join(args.root_folder, 'usps_test_dataset.pkl')

    else:
        raise Exception('Task cannot be recognized!')

    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)
    out_log_file = open(os.path.join(args.output_dir, "log.txt"), "w")
    out_log_file_train = open(os.path.join(args.output_dir, "log_train.txt"),
                              "w")

    model = model.to(args.device)
    class_num = 10

    if args.lr > 0:
        lr = args.lr

    print('Starting loading data')
    sys.stdout.flush()
    t_data = time.time()
    if os.path.exists(source_path):
        print('Found existing dataset for source')
        with open(source_path, 'rb') as f:
            [source_samples, source_labels] = pickle.load(f)
            source_samples, source_labels = torch.Tensor(source_samples).to(
                args.device), torch.LongTensor(source_labels).to(args.device)
    else:
        print('Building dataset for source and writing to {}'.format(
            source_path))
        source_samples, source_labels = build_dataset(source_list, source_path,
                                                      args.root_folder,
                                                      args.device)

    if os.path.exists(target_path):
        print('Found existing dataset for target')
        with open(target_path, 'rb') as f:
            [target_samples, target_labels] = pickle.load(f)
            target_samples, target_labels = torch.Tensor(target_samples).to(
                args.device), torch.LongTensor(target_labels).to(args.device)
    else:
        print('Building dataset for target and writing to {}'.format(
            target_path))
        target_samples, target_labels = build_dataset(target_list, target_path,
                                                      args.root_folder,
                                                      args.device)

    if os.path.exists(test_path):
        print('Found existing dataset for test')
        with open(test_path, 'rb') as f:
            [test_samples, test_labels] = pickle.load(f)
            test_samples, test_labels = torch.Tensor(test_samples).to(
                args.device), torch.LongTensor(test_labels).to(args.device)
    else:
        print('Building dataset for test and writing to {}'.format(test_path))
        test_samples, test_labels = build_dataset(test_list, test_path,
                                                  args.root_folder,
                                                  args.device)

    print('Data loaded in {}'.format(time.time() - t_data))

    if args.ratio == 1:
        # RATIO OPTION 1
        # 30% of the samples from the first 5 classes
        print('Using option 1, ie [0.3] * 5 + [1] * 5')
        ratios_source = [0.3] * 5 + [1] * 5
        ratios_target = [1] * 10
    elif args.ratio >= 200:
        s_ = subsampling[int(args.ratio) % 100]
        ratios_source = s_[0]
        ratios_target = [1] * 10
        print(
            'Using random subset ratio {} of the source, with theoretical jsd {}'
            .format(args.ratio, s_[1]))
    elif 200 > args.ratio >= 100:
        s_ = subsampling[int(args.ratio) % 100]
        ratios_source = [1] * 10
        ratios_target = s_[0]
        print(
            'Using random subset ratio {} of the target, with theoretical jsd {}'
            .format(args.ratio, s_[1]))
    else:
        # ORIGINAL DATASETS
        print('Using original datasets')
        ratios_source = [1] * 10
        ratios_target = [1] * 10
    ratios_test = ratios_target

    # Subsample dataset if need be
    source_samples, source_labels = sample_ratios(source_samples,
                                                  source_labels, ratios_source)
    target_samples, target_labels = sample_ratios(target_samples,
                                                  target_labels, ratios_target)
    test_samples, test_labels = sample_ratios(test_samples, test_labels,
                                              ratios_test)

    # compute labels distribution on the source and target domain
    source_label_distribution = np.zeros((class_num))
    for img in source_labels:
        source_label_distribution[int(img.item())] += 1
    print("Total source samples: {}".format(np.sum(source_label_distribution)),
          flush=True)
    print("Source samples per class: {}".format(source_label_distribution))
    source_label_distribution /= np.sum(source_label_distribution)
    write_list(out_log_file, source_label_distribution)
    print("Source label distribution: {}".format(source_label_distribution))
    target_label_distribution = np.zeros((class_num))
    for img in target_labels:
        target_label_distribution[int(img.item())] += 1
    print("Total target samples: {}".format(np.sum(target_label_distribution)),
          flush=True)
    print("Target samples per class: {}".format(target_label_distribution))
    target_label_distribution /= np.sum(target_label_distribution)
    write_list(out_log_file, target_label_distribution)
    print("Target label distribution: {}".format(target_label_distribution))
    test_label_distribution = np.zeros((class_num))
    for img in test_labels:
        test_label_distribution[int(img.item())] += 1
    print("Test samples per class: {}".format(test_label_distribution))
    test_label_distribution /= np.sum(test_label_distribution)
    write_list(out_log_file, test_label_distribution)
    print("Test label distribution: {}".format(test_label_distribution))
    mixture = (source_label_distribution + target_label_distribution) / 2
    jsd = (scipy.stats.entropy(source_label_distribution, qk=mixture) +
           scipy.stats.entropy(target_label_distribution, qk=mixture)) / 2
    print("JSD source to target : {}".format(jsd))
    mixture_2 = (test_label_distribution + target_label_distribution) / 2
    jsd_2 = (scipy.stats.entropy(test_label_distribution, qk=mixture_2) +
             scipy.stats.entropy(target_label_distribution, qk=mixture_2)) / 2
    print("JSD test to target : {}".format(jsd_2))
    out_wei_file = open(
        os.path.join(args.output_dir, "log_weights_{}.txt".format(jsd)), "w")
    write_list(out_wei_file, [round(x, 4) for x in source_label_distribution])
    write_list(out_wei_file, [round(x, 4) for x in target_label_distribution])
    out_wei_file.write(str(jsd) + "\n")
    true_weights = torch.tensor(target_label_distribution /
                                source_label_distribution,
                                dtype=torch.float,
                                requires_grad=False)[:, None].to(args.device)
    print("True weights : {}".format(true_weights[:, 0].cpu().numpy()))

    if 'CDAN' in args.method:
        ad_net = network.AdversarialNetwork(model.output_num() * class_num,
                                            500,
                                            sigmoid='WDANN' not in args.method)
    else:
        ad_net = network.AdversarialNetwork(model.output_num(),
                                            500,
                                            sigmoid='WDANN' not in args.method)
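    # CDAN variants condition the discriminator on the feature/prediction outer product,
    # so its input dimension is output_num() * class_num; the other methods discriminate
    # on the raw feature dimension.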

    ad_net = ad_net.to(args.device)

    optimizer = optim.SGD(model.parameters(),
                          lr=lr,
                          weight_decay=0.0005,
                          momentum=0.9)
    optimizer_ad = optim.SGD(ad_net.parameters(),
                             lr=lr,
                             weight_decay=0.0005,
                             momentum=0.9)

    # Maintain two quantities for the QP.
    cov_mat = torch.tensor(np.zeros((class_num, class_num), dtype=np.float32),
                           requires_grad=False).to(args.device)
    pseudo_target_label = torch.tensor(np.zeros((class_num, 1),
                                                dtype=np.float32),
                                       requires_grad=False).to(args.device)
    # Maintain one weight vector for BER.
    class_weights = torch.tensor(1.0 / source_label_distribution,
                                 dtype=torch.float,
                                 requires_grad=False).to(args.device)

    for epoch in range(1, args.epochs + 1):
        start_time_test = time.time()
        if epoch % decay_epoch == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * decay_frac
        test(args,
             epoch,
             model,
             test_samples,
             test_labels,
             start_time_test,
             out_log_file,
             name='Target test')
        train(args, model, ad_net, source_samples, source_labels,
              target_samples, target_labels, optimizer, optimizer_ad, epoch,
              start_epoch, args.method, source_label_distribution,
              out_wei_file, cov_mat, pseudo_target_label, class_weights,
              true_weights)
    test(args,
         epoch + 1,
         model,
         test_samples,
         test_labels,
         start_time_test,
         out_log_file,
         name='Target test')
    test(args,
         epoch + 1,
         model,
         source_samples,
         source_labels,
         start_time_test,
         out_log_file_train,
         name='Source train')
def train(config, data_import):
    class_num = config["network"]["params"]["class_num"]
    loss_params = config["loss"]

    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    center_criterion = config["loss"]["discriminant_loss"](
        num_classes=class_num,
        feat_dim=config["network"]["params"]["bottleneck_dim"])

    ## prepare data
    pristine_x, pristine_y, noisy_x, noisy_y = data_import
    dsets = {}
    dset_loaders = {}

    #sampling WOR
    pristine_indices = torch.randperm(len(pristine_x))

    pristine_x_train = pristine_x[pristine_indices]
    pristine_y_train = pristine_y[pristine_indices]

    noisy_indices = torch.randperm(len(noisy_x))
    noisy_x_train = noisy_x[noisy_indices]
    noisy_y_train = noisy_y[noisy_indices]

    dsets["source"] = TensorDataset(pristine_x_train, pristine_y_train)
    dsets["target"] = TensorDataset(noisy_x_train, noisy_y_train)

    dset_loaders["source"] = DataLoader(dsets["source"],
                                        batch_size=128,
                                        shuffle=True,
                                        num_workers=1)
    dset_loaders["target"] = DataLoader(dsets["target"],
                                        batch_size=128,
                                        shuffle=True,
                                        num_workers=1)

    config["num_iterations"] = len(
        dset_loaders["source"]) * config["epochs"] + 1
    config["test_interval"] = len(dset_loaders["source"])
    config["snapshot_interval"] = len(
        dset_loaders["source"]) * config["epochs"] * .25
    config["log_iter"] = len(dset_loaders["source"])

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    if config["ckpt_path"] is not None:
        ckpt = torch.load(config['ckpt_path'] + '/best_model.pth.tar',
                          map_location=torch.device('cpu'))
        base_network.load_state_dict(ckpt['base_network'])

    use_gpu = torch.cuda.is_available()

    if use_gpu:
        base_network = base_network.cuda()

    ## add additional network for some methods
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(
        high_value=config["high"])

    if use_gpu:
        ad_net = ad_net.cuda()

        ## collect parameters
    if "DeepMerge" in config["net"]:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
        parameter_list.append({
            "params": ad_net.parameters(),
            "lr_mult": .1,
            'decay_mult': 2
        })
        parameter_list.append({
            "params": center_criterion.parameters(),
            "lr_mult": 10,
            'decay_mult': 1
        })
    elif "ResNet18" in config["net"]:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
        parameter_list.append({
            "params": ad_net.parameters(),
            "lr_mult": .1,
            'decay_mult': 2
        })
        parameter_list.append({
            "params": center_criterion.parameters(),
            "lr_mult": 10,
            'decay_mult': 1
        })

        if net_config["params"]["new_cls"]:
            if net_config["params"]["use_bottleneck"]:
                parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                                {"params":base_network.bottleneck.parameters(), "lr_mult":10, 'decay_mult':2}, \
                                {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
                parameter_list.append({
                    "params": ad_net.parameters(),
                    "lr_mult": config["ad_net_mult_lr"],
                    'decay_mult': 2
                })
                parameter_list.append({
                    "params": center_criterion.parameters(),
                    "lr_mult": 10,
                    'decay_mult': 1
                })
            else:
                parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                                {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
                parameter_list.append({
                    "params": ad_net.parameters(),
                    "lr_mult": config["ad_net_mult_lr"],
                    'decay_mult': 2
                })
                parameter_list.append({
                    "params": center_criterion.parameters(),
                    "lr_mult": 10,
                    'decay_mult': 1
                })
    else:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
        parameter_list.append({
            "params": ad_net.parameters(),
            "lr_mult": config["ad_net_mult_lr"],
            'decay_mult': 2
        })
        parameter_list.append({
            "params": center_criterion.parameters(),
            "lr_mult": 10,
            'decay_mult': 1
        })

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])

    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0

    for i in range(config["num_iterations"]):

        ## train one iter
        base_network.train(True)

        if i % config["log_iter"] == 0:
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        if config["optimizer"]["lr_type"] == "one-cycle":
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)
        elif config["optimizer"]["lr_type"] == "linear":
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        optim = optimizer.state_dict()

        optimizer.zero_grad()

        try:
            inputs_source, labels_source = next(iter_source)
            inputs_target, labels_target = next(iter_target)
        except (NameError, StopIteration):
            # (Re)create the iterators on first use or when a loader is exhausted.
            iter_source = iter(dset_loaders["source"])
            iter_target = iter(dset_loaders["target"])
            inputs_source, labels_source = next(iter_source)
            inputs_target, labels_target = next(iter_target)

        if use_gpu:
            inputs_source, inputs_target, labels_source = \
                Variable(inputs_source).cuda(), Variable(inputs_target).cuda(), \
                Variable(labels_source).cuda()
        else:
            inputs_source, inputs_target, labels_source = Variable(inputs_source), \
                Variable(inputs_target), Variable(labels_source)

        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        source_batch_size = inputs_source.size(0)

        if config['loss']['ly_type'] == 'cosine':
            features, logits = base_network(inputs)
            source_logits = logits.narrow(0, 0, source_batch_size)
        elif config['loss']['ly_type'] == 'euclidean':
            features, _ = base_network(inputs)
            logits = -1.0 * loss.distance_to_centroids(
                features, center_criterion.centers.detach())
            source_logits = logits.narrow(0, 0, source_batch_size)
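            # For the 'euclidean' variant, class scores are the negative distances from
            # each feature to the (detached) class centroids maintained by the center
            # criterion, so a smaller distance means a larger logit.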

        ad_net.train(True)
        weight_ad = torch.ones(inputs.size(0))
        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                            weight_ad, use_gpu)
        ad_out, _ = ad_net(features.detach())
        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(ad_out)

        # source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source.long())

        # fisher loss on labeled source domain
        if config["fisher_or_no"] == 'no':
            total_loss = loss_params["trade_off"] * transfer_loss \
            + classifier_loss
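            # NOTE: the reassignment below overrides the trade_off formulation above; what
            # is actually optimized is the classifier loss plus penalties that push the
            # domain classifier's per-domain accuracy toward 0.5.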

            total_loss = classifier_loss + classifier_loss * (
                0.5 - source_acc_ad)**2 + classifier_loss * (0.5 -
                                                             target_acc_ad)**2

            total_loss.backward()

            optimizer.step()

        else:
            fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(
                features.narrow(0, 0, int(inputs.size(0) / 2)),
                labels_source,
                inter_class=config["loss"]["inter_type"],
                intra_loss_weight=loss_params["intra_loss_coef"],
                inter_loss_weight=loss_params["inter_loss_coef"])
            # entropy minimization loss
            em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

            # final loss
            total_loss = loss_params["trade_off"] * transfer_loss \
                         + fisher_loss \
                         + loss_params["em_loss_coef"] * em_loss \
                         + classifier_loss
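            # NOTE: as in the 'no' branch, the reassignment below replaces the trade_off
            # formulation above; the transfer_loss term is dropped and the adversarial
            # accuracy penalties are used instead.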

            total_loss = classifier_loss + fisher_loss + loss_params[
                "em_loss_coef"] * em_loss + classifier_loss * (
                    0.5 - source_acc_ad)**2 + classifier_loss * (
                        0.5 - target_acc_ad)**2

            total_loss.backward()

            if center_grad is not None:
                # clear mmc_loss
                center_criterion.centers.grad.zero_()
                # Manually assign centers gradients other than using autograd
                center_criterion.centers.backward(center_grad)

            optimizer.step()

    return (-1 * total_loss.cpu().float().item())
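The quadratic penalties above (`classifier_loss * (0.5 - source_acc_ad)**2` and the matching target term) inflate the classification loss whenever the domain discriminator strays from chance-level accuracy, which encourages domain-confused features. A minimal sketch of a `domain_cls_accuracy`-style helper is shown below; it assumes the discriminator emits one sigmoid probability per sample and that the first half of the batch is source, the second half target (names and the 0.5 threshold are illustrative assumptions, not the repository's exact code).

import torch

def domain_cls_accuracy_sketch(ad_out):
    # ad_out: (2B,) or (2B, 1) sigmoid outputs; first B samples source, last B target
    ad_out = ad_out.view(-1)
    half = ad_out.size(0) // 2
    pred_source = (ad_out > 0.5).float()                    # 1 -> predicted "source"
    source_acc = pred_source[:half].mean().item()           # source samples classified as source
    target_acc = (1.0 - pred_source[half:]).mean().item()   # target samples classified as target
    overall_acc = 0.5 * (source_acc + target_acc)
    return overall_acc, source_acc, target_acc

When both accuracies sit near 0.5 the penalty terms vanish, so the objective only grows while the discriminator can still tell the domains apart.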
Exemplo n.º 18
0
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='CDAN USPS MNIST')
    parser.add_argument('--method',
                        type=str,
                        default='CDAN-E',
                        choices=['CDAN', 'CDAN-E', 'DANN'])
    parser.add_argument('--task', default='USPS2MNIST', help='task to perform')
    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        help='input batch size for training (default: 128)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=550,
                        metavar='N',
                        help='number of epochs to train (default: 550)')
    parser.add_argument('--lr',
                        type=float,
                        default=5e-5,
                        metavar='LR',
                        help='learning rate (default: 5e-5)')
    parser.add_argument('--lr2',
                        type=float,
                        default=0.005,
                        metavar='LR2',
                        help='learning rate for the adversarial network (default: 0.005)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--gpu_id',
                        type=str,
                        default='0',
                        help='cuda device id')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        help='how many batches to wait before logging training status')
    parser.add_argument('--random',
                        type=bool,
                        default=False,
                        help='whether to use random')
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    if args.task == 'USPS2MNIST':
        source_list, ordinary_train_dataset, target_list, test_list, ccp = data_loader(
            task='U2M')
        start_epoch = 50
        decay_epoch = 600
    elif args.task == 'MNIST2USPS':
        source_list, ordinary_train_dataset, target_list, test_list, ccp = data_loader(
            task='M2U')
        start_epoch = 50
        decay_epoch = 600
    else:
        raise Exception('task cannot be recognized!')

    train_loader = torch.utils.data.DataLoader(dataset=source_list,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=8,
                                               drop_last=True)
    train_loader1 = torch.utils.data.DataLoader(dataset=target_list,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=8,
                                                drop_last=True)
    o_train_loader = torch.utils.data.DataLoader(
        dataset=ordinary_train_dataset,
        batch_size=args.test_batch_size,
        shuffle=True,
        num_workers=8)
    test_loader = torch.utils.data.DataLoader(dataset=test_list,
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              num_workers=8)

    model = network.LeNet()
    model = model.cuda()
    class_num = 10

    if args.random:
        random_layer = network.RandomLayer([model.output_num(), class_num],
                                           500)
        ad_net = network.AdversarialNetwork(500, 500)
        random_layer.cuda()
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(model.output_num() * class_num,
                                            500)
    ad_net = ad_net.cuda()

    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          weight_decay=0.0005,
                          momentum=0.9)
    optimizer_ad = optim.SGD(ad_net.parameters(),
                             lr=args.lr2,
                             weight_decay=0.0005,
                             momentum=0.9)

    save_table = np.zeros(shape=(args.epochs, 3))
    for epoch in range(1, args.epochs + 1):
        if epoch % decay_epoch == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.5
        train(args, model, ad_net, random_layer, train_loader, train_loader1,
              optimizer, optimizer_ad, epoch, start_epoch, args.method, ccp)
        acc1 = test(args, model, o_train_loader)
        acc2 = test(args, model, test_loader)
        save_table[epoch - 1, :] = epoch - 50, acc1, acc2
        np.savetxt(args.task + '_.txt', save_table, delimiter=',', fmt='%1.3f')
    np.savetxt(args.task + '_.txt', save_table, delimiter=',', fmt='%1.3f')
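When `--random` is off above, the discriminator input size is `model.output_num() * class_num`; with the random layer it is a fixed 500. That dimensionality comes from conditioning the domain discriminator on the classifier's predictions via a per-sample outer product of features and softmax outputs (CDAN-style conditioning). A minimal sketch with illustrative names, not the exact `loss.CDAN` code:

import torch

def conditional_input_sketch(features, softmax_output):
    # features: (B, d), softmax_output: (B, C) -> conditioned input of shape (B, C * d)
    op = torch.bmm(softmax_output.unsqueeze(2), features.unsqueeze(1))  # per-sample outer product
    return op.view(features.size(0), -1)

The random-layer variant instead projects features and predictions separately into 500 dimensions and multiplies the projections element-wise, keeping the discriminator input size independent of d and C.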
Exemplo n.º 19
0
def train(config):
    base_network = network.ResNetFc('ResNet50', use_bottleneck=True, bottleneck_dim=config["bottleneck_dim"], new_cls=True, class_num=config["class_num"])
    ad_net = network.AdversarialNetwork(config["bottleneck_dim"], config["hidden_dim"])

    base_network = base_network.cuda()
    ad_net = ad_net.cuda()

    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    source_path = ImageList(open(config["s_path"]).readlines(), transform=preprocess.image_train(resize_size=256, crop_size=224))
    target_path = ImageList(open(config["t_path"]).readlines(), transform=preprocess.image_train(resize_size=256, crop_size=224))
    test_path   = ImageList(open(config["t_path"]).readlines(), transform=preprocess.image_test(resize_size=256, crop_size=224))

    source_loader = DataLoader(source_path, batch_size=config["train_bs"], shuffle=True, num_workers=0, drop_last=True)
    target_loader = DataLoader(target_path, batch_size=config["train_bs"], shuffle=True, num_workers=0, drop_last=True)
    test_loader   = DataLoader(test_path, batch_size=config["test_bs"], shuffle=True, num_workers=0, drop_last=True)

    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    gpus = config["gpus"].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
        base_network = nn.DataParallel(base_network, device_ids=[int(i) for i in gpus])


    len_train_source = len(source_loader)
    len_train_target = len(target_loader)

    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    best_model_path = None

    for i in trange(config["iterations"], leave=False):
        if i % config["test_interval"] == config["test_interval"] - 1:
            base_network.train(False)
            temp_acc = image_classification_test(test_loader, base_network)
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = copy.deepcopy(temp_model)
                best_iter = i
                if best_model_path and osp.exists(best_model_path):
                    try:
                        os.remove(best_model_path)
                    except OSError:
                        pass
                best_model_path = osp.join(config["output_path"], "iter_{:05d}.pth.tar".format(best_iter))
                torch.save(best_model, best_model_path)
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str+"\n")
            config["out_file"].flush()
            # print("cut_loss: ", cut_loss.item())
            print("mix_loss: ", mix_loss.item())
            print(log_str)

        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(source_loader)
        if i % len_train_target == 0:
            iter_target = iter(target_loader)

        inputs_source, labels_source = iter_source.next()
        inputs_target, labels_target = iter_target.next()
        inputs_source, inputs_target, labels_source = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()
        labels_src_one_hot = torch.nn.functional.one_hot(labels_source, config["class_num"]).float()

        # inputs_cut, labels_cut = cutmix(base_network, inputs_source, labels_src_one_hot, inputs_target, config["alpha"], config["class_num"])
        inputs_mix, labels_mix = mixup(base_network, inputs_source, labels_src_one_hot, inputs_target, config["alpha"], config["class_num"], config["temperature"])

        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)
        # features_cut,    outputs_cut    = base_network(inputs_cut)
        features_mix,    outputs_mix    = base_network(inputs_mix)

        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)

        if config["method"] == 'DANN':
            transfer_loss = loss.DANN(features, ad_net)
            # cut_loss = utils.kl_loss(outputs_cut, labels_cut.detach())
            mix_loss = utils.kl_loss(outputs_mix, labels_mix.detach())
        else:
            raise ValueError('Method cannot be recognized.')

        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        total_loss = transfer_loss + classifier_loss + (5*mix_loss)
        total_loss.backward()
        optimizer.step()
    torch.save(best_model, osp.join(config["output_path"], "best_model.pth.tar"))
    print("Training Finished! Best Accuracy: ", best_acc)
    return best_acc
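The `mixup(...)` helper called above is not reproduced here. A minimal sketch of the idea, under the assumption that the unlabeled target side is labeled with the network's own temperature-sharpened predictions and that the mixing weight is drawn from a Beta(alpha, alpha) distribution (names and the lambda handling are assumptions, not the repository's exact implementation):

import numpy as np
import torch

def mixup_sketch(model, x_src, y_src_onehot, x_tgt, alpha, temperature):
    with torch.no_grad():
        _, logits_tgt = model(x_tgt)                            # base_network returns (features, logits)
        y_tgt = torch.softmax(logits_tgt / temperature, dim=1)  # sharpened pseudo-labels for the target batch
    lam = np.random.beta(alpha, alpha)
    lam = max(lam, 1.0 - lam)                                   # keep the labeled source side dominant
    x_mix = lam * x_src + (1.0 - lam) * x_tgt
    y_mix = lam * y_src_onehot + (1.0 - lam) * y_tgt
    return x_mix, y_mix.detach()

The mixed predictions are then matched against `labels_mix` with a KL-style loss, as in `utils.kl_loss(outputs_mix, labels_mix.detach())` above.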
Exemplo n.º 20
0
def train(args):
    ## prepare data
    train_bs, test_bs = args.batch_size, args.batch_size * 2

    dsets = {}
    dsets["source"] = data_list.ImageList(open(args.s_dset_path).readlines(),
                                          transform=image_train())
    dsets["target"] = data_list.ImageList(open(args.t_dset_path).readlines(),
                                          transform=image_train())
    dsets["test"] = data_list.ImageList(open(args.t_dset_path).readlines(),
                                        transform=image_test())

    dset_loaders = {}
    dset_loaders["source"] = DataLoader(dsets["source"],
                                        batch_size=train_bs,
                                        shuffle=True,
                                        num_workers=args.worker,
                                        drop_last=True)
    dset_loaders["target"] = DataLoader(dsets["target"],
                                        batch_size=train_bs,
                                        shuffle=True,
                                        num_workers=args.worker,
                                        drop_last=True)
    dset_loaders["test"] = DataLoader(dsets["test"],
                                      batch_size=test_bs,
                                      shuffle=False,
                                      num_workers=args.worker)

    if "ResNet" in args.net:
        params = {
            "resnet_name": args.net,
            "use_bottleneck": True,
            "bottleneck_dim": 256,
            "new_cls": True,
            'class_num': args.class_num
        }
        base_network = network.ResNetFc(**params)

    if "VGG" in args.net:
        params = {
            "vgg_name": args.net,
            "use_bottleneck": True,
            "bottleneck_dim": 256,
            "new_cls": True,
            'class_num': args.class_num
        }
        base_network = network.VGGFc(**params)

    base_network = base_network.cuda()

    ad_net = network.AdversarialNetwork(base_network.output_num(), 1024,
                                        args.max_iterations).cuda()
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()
    # ad_net = torch.nn.DataParallel(ad_net).cuda()
    # base_network = torch.nn.DataParallel(base_network).cuda()

    ## set optimizer
    optimizer_config = {
        "type": torch.optim.SGD,
        "optim_params": {
            'lr': args.lr,
            "momentum": 0.9,
            "weight_decay": 5e-4,
            "nesterov": True
        },
        "lr_type": "inv",
        "lr_param": {
            "lr": args.lr,
            "gamma": 0.001,
            "power": 0.75
        }
    }
    optimizer = optimizer_config["type"](parameter_list,
                                         **(optimizer_config["optim_params"]))

    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    class_weight = None
    best_ent = 1000
    total_epochs = args.max_iterations // args.test_interval

    for i in range(args.max_iterations + 1):
        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)

        if (i % args.test_interval == 0 and i > 0) or (i
                                                       == args.max_iterations):
            # obtain the class-level weight and evaluate the current model
            base_network.train(False)
            temp_acc, class_weight, mean_ent = image_classification(
                dset_loaders, base_network)
            class_weight = class_weight.cuda().detach()

            temp = [round(i, 4) for i in class_weight.cpu().numpy().tolist()]
            log_str = str(temp)
            args.out_file.write(log_str + "\n")
            args.out_file.flush()

            print(class_weight)
            if mean_ent < best_ent:
                best_ent, best_acc = mean_ent, temp_acc
                best_model = base_network.state_dict()
            log_str = "iter: {:05d}, precision: {:.5f}, mean_entropy: {:.5f}".format(
                i, temp_acc, mean_ent)
            args.out_file.write(log_str + "\n")
            args.out_file.flush()
            print(log_str)

        if i % args.test_interval == 0:
            if args.mu > 0:
                epoch = i // args.test_interval
                len_share = int(
                    max(0, (train_bs // args.mu) * (1 - epoch / total_epochs)))
            elif args.mu == 0:
                len_share = 0  # no augmentation
            else:
                len_share = int(train_bs // abs(args.mu))  # fixed augmentation
            log_str = "\n{}, iter: {:05d}, source/ target/ middle: {:02d} / {:02d} / {:02d}\n".format(
                args.name, i, train_bs, train_bs, len_share)
            args.out_file.write(log_str)
            args.out_file.flush()
            print(log_str)

            dset_loaders["middle"] = None
            if not len_share == 0:
                dset_loaders["middle"] = DataLoader(dsets["source"],
                                                    batch_size=len_share,
                                                    shuffle=True,
                                                    num_workers=args.worker,
                                                    drop_last=True)
                iter_middle = iter(dset_loaders["middle"])

        # train one iter
        if i % len(dset_loaders["source"]) == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len(dset_loaders["target"]) == 0:
            iter_target = iter(dset_loaders["target"])
        if dset_loaders["middle"] is not None and i % len(
                dset_loaders["middle"]) == 0:
            iter_middle = iter(dset_loaders["middle"])

        inputs_source, labels_source = iter_source.next()
        inputs_target, _ = iter_target.next()
        inputs_source, inputs_target, labels_source = inputs_source.cuda(
        ), inputs_target.cuda(), labels_source.cuda()

        if class_weight is not None and args.weight_cls and class_weight[
                labels_source].sum() == 0:
            continue

        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)

        if dset_loaders["middle"] is not None:
            inputs_middle, labels_middle = iter_middle.next()
            features_middle, outputs_middle = base_network(
                inputs_middle.cuda())
            features = torch.cat(
                (features_source, features_target, features_middle), dim=0)
            outputs = torch.cat(
                (outputs_source, outputs_target, outputs_middle), dim=0)
        else:
            features = torch.cat((features_source, features_target), dim=0)
            outputs = torch.cat((outputs_source, outputs_target), dim=0)

        cls_weight = torch.ones(outputs.size(0)).cuda()
        if class_weight is not None and args.weight_aug:
            cls_weight[0:train_bs] = class_weight[labels_source]
            if dset_loaders["middle"] is not None:
                cls_weight[2 * train_bs::] = class_weight[labels_middle]

        # compute source cross-entropy loss
        if class_weight is not None and args.weight_cls:
            src_ = torch.nn.CrossEntropyLoss(reduction='none')(outputs_source,
                                                               labels_source)
            weight = class_weight[labels_source].detach()
            src_loss = torch.sum(
                weight * src_) / (1e-8 + torch.sum(weight).item())
        else:
            src_loss = torch.nn.CrossEntropyLoss()(outputs_source,
                                                   labels_source)

        softmax_out = torch.nn.Softmax(dim=1)(outputs)
        entropy = my_loss.Entropy(softmax_out)
        transfer_loss = my_loss.DANN(
            features, ad_net, entropy,
            network.calc_coeff(i, 1, 0, 10, args.max_iterations), cls_weight,
            len_share)

        softmax_tar_out = torch.nn.Softmax(dim=1)(outputs_target)
        tar_loss = torch.mean(my_loss.Entropy(softmax_tar_out))

        total_loss = src_loss + transfer_loss + args.ent_weight * tar_loss
        if args.cot_weight > 0:
            if class_weight is not None and args.weight_cls:
                cot_loss = my_loss.marginloss(
                    outputs_source,
                    labels_source,
                    args.class_num,
                    alpha=args.alpha,
                    weight=class_weight[labels_source].detach())
            else:
                cot_loss = my_loss.marginloss(outputs_source,
                                              labels_source,
                                              args.class_num,
                                              alpha=args.alpha)
            total_loss += cot_loss * args.cot_weight

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    torch.save(best_model, os.path.join(args.output_dir, "best_model.pt"))

    log_str = 'Acc: ' + str(np.round(best_acc * 100,
                                     2)) + "\n" + 'Mean_ent: ' + str(
                                         np.round(best_ent, 3)) + '\n'
    args.out_file.write(log_str)
    args.out_file.flush()
    print(log_str)

    return best_acc
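The entropy-conditioned DANN call above relies on two small utilities: a per-sample entropy of the softmax outputs and a coefficient that ramps the adversarial weight up over training. Implementations in this family of repositories typically look like the following sketch (treat the exact constants as assumptions):

import numpy as np
import torch

def entropy(softmax_output):
    # per-sample entropy of a softmax distribution, shape (B,)
    return -torch.sum(softmax_output * torch.log(softmax_output + 1e-5), dim=1)

def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
    # smooth ramp from `low` to `high` as iter_num goes from 0 to max_iter
    return float(2.0 * (high - low) / (1.0 + np.exp(-alpha * iter_num / max_iter))
                 - (high - low) + low)

Low-entropy (confidently classified) samples receive larger weights inside the adversarial loss, so the domain alignment is driven mainly by examples the classifier is already sure about.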
Exemplo n.º 21
0
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]

    source_list = [
        '.' + i for i in open(data_config["source"]["list_path"]).readlines()
    ]
    target_list = [
        '.' + i for i in open(data_config["target"]["list_path"]).readlines()
    ]

    dsets["source"] = ImageList(source_list, \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
            shuffle=True, num_workers=config['args'].num_worker, drop_last=True)
    dsets["target"] = ImageList(target_list, \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
            shuffle=True, num_workers=config['args'].num_worker, drop_last=True)
    print("source dataset len:", len(dsets["source"]))
    print("target dataset len:", len(dsets["target"]))

    if prep_config["test_10crop"]:
        test_list = [
            '.' + i
            for i in open(data_config["test"]["list_path"]).readlines()
        ]
        dsets["test"] = [ImageList(test_list, \
                            transform=prep_dict["test"][i]) for i in range(10)]
        dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                            shuffle=False, num_workers=config['args'].num_worker) for dset in dsets['test']]
    else:
        test_list = [
            '.' + i
            for i in open(data_config["test"]["list_path"]).readlines()
        ]
        dsets["test"] = ImageList(test_list, \
                                transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                shuffle=False, num_workers=config['args'].num_worker)

    dsets["target_label"] = ImageList_label(target_list, \
                            transform=prep_dict["target"])
    dset_loaders["target_label"] = DataLoader(dsets["target_label"], batch_size=test_bs, \
            shuffle=False, num_workers=config['args'].num_worker, drop_last=False)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.to(network.dev)
    if config["restore_path"]:
        checkpoint = torch.load(
            osp.join(config["restore_path"], "best_model.pth"))["base_network"]
        ckp = {}
        for k, v in checkpoint.items():
            if "module" in k:
                ckp[k.split("module.")[-1]] = v
            else:
                ckp[k] = v
        base_network.load_state_dict(ckp)
        log_str = "successfully restore from {}".format(
            osp.join(config["restore_path"], "best_model.pth"))
        config["out_file"].write(log_str + "\n")
        config["out_file"].flush()
        print(log_str)

    ## add additional network for some methods
    if "ALDA" in args.method:
        ad_net = network.Multi_AdversarialNetwork(base_network.output_num(),
                                                  1024, class_num)
    else:
        ad_net = network.AdversarialNetwork(base_network.output_num(), 1024)
    ad_net = ad_net.to(network.dev)
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net,
                                 device_ids=[int(i) for i in range(len(gpus))])
        base_network = nn.DataParallel(
            base_network, device_ids=[int(i) for i in range(len(gpus))])

    loss_params = config["loss"]
    high = loss_params["trade_off"]
    begin_label = False
    writer = SummaryWriter(config["output_path"])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    loss_value = 0
    loss_adv_value = 0
    loss_correct_value = 0
    for i in tqdm(range(config["num_iterations"]),
                  total=config["num_iterations"]):
        if i % config["test_interval"] == config["test_interval"] - 1:
            base_network.train(False)
            temp_acc = image_classification_test(dset_loaders, \
                base_network, test_10crop=prep_config["test_10crop"])
            temp_model = base_network  #nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_step = i
                best_acc = temp_acc
                best_model = temp_model
                checkpoint = {
                    "base_network": best_model.state_dict(),
                    "ad_net": ad_net.state_dict()
                }
                torch.save(checkpoint,
                           osp.join(config["output_path"], "best_model.pth"))
                print(
                    "\n##########     save the best model.    #############\n")
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()
            writer.add_scalar('precision', temp_acc, i)
            print(log_str)

            print("adv_loss: {:.3f} correct_loss: {:.3f} class_loss: {:.3f}".
                  format(loss_adv_value, loss_correct_value, loss_value))
            loss_value = 0
            loss_adv_value = 0
            loss_correct_value = 0

            #show val result on tensorboard
            images_inv = prep.inv_preprocess(inputs_source.clone().cpu(), 3)
            for index, img in enumerate(images_inv):
                writer.add_image(str(index) + '/Images', img, i)

        # save the pseudo_label
        if 'PseudoLabel' in config['method'] and (
                i % config["label_interval"] == config["label_interval"] - 1):
            base_network.train(False)
            pseudo_label_list = image_label(dset_loaders, base_network, threshold=config['threshold'], \
                                out_dir=config["output_path"])
            dsets["target"] = ImageList(open(pseudo_label_list).readlines(), \
                                transform=prep_dict["target"])
            dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
                    shuffle=True, num_workers=config['args'].num_worker, drop_last=True)
            iter_target = iter(
                dset_loaders["target"]
            )  # replace the target dataloader with Pseudo_Label dataloader
            begin_label = True

        if i > config["stop_step"]:
            log_str = "method {}, iter: {:05d}, precision: {:.5f}".format(
                config["output_path"], best_step, best_acc)
            config["final_log"].write(log_str + "\n")
            config["final_log"].flush()
            break

        ## train one iter
        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = iter_source.next()
        inputs_target, labels_target = iter_target.next()
        inputs_source, inputs_target, labels_source = Variable(
            inputs_source).to(network.dev), Variable(inputs_target).to(
                network.dev), Variable(labels_source).to(network.dev)
        features_source, outputs_source = base_network(inputs_source)
        if args.source_detach:
            features_source = features_source.detach()
        features_target, outputs_target = base_network(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)
        loss_params["trade_off"] = network.calc_coeff(
            i, high=high)  #if i > 500 else 0.0
        transfer_loss = 0.0
        if 'DANN' in config['method']:
            transfer_loss = loss.DANN(features, ad_net)
        elif "ALDA" in config['method']:
            ad_out = ad_net(features)
            adv_loss, reg_loss, correct_loss = loss.ALDA_loss(
                ad_out,
                labels_source,
                softmax_out,
                weight_type=config['args'].weight_type,
                threshold=config['threshold'])
            # whether add the corrected self-training loss
            if "nocorrect" in config['args'].loss_type:
                transfer_loss = adv_loss
            else:
                transfer_loss = config['args'].adv_weight * adv_loss + config[
                    'args'].adv_weight * loss_params["trade_off"] * correct_loss
            # reg_loss is only backward to the discriminator
            if "noreg" not in config['args'].loss_type:
                for param in base_network.parameters():
                    param.requires_grad = False
                reg_loss.backward(retain_graph=True)
                for param in base_network.parameters():
                    param.requires_grad = True
        # on-line self-training
        elif 'SelfTraining' in config['method']:
            transfer_loss += loss_params["trade_off"] * loss.SelfTraining_loss(
                outputs, softmax_out, config['threshold'])
        # off-line self-training
        elif 'PseudoLabel' in config['method']:
            labels_target = labels_target.to(network.dev)
            if begin_label:
                transfer_loss += loss_params["trade_off"] * nn.CrossEntropyLoss(
                    ignore_index=-1)(outputs_target, labels_target)
            else:
                transfer_loss += 0.0 * nn.CrossEntropyLoss(ignore_index=-1)(
                    outputs_target, labels_target)

        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        loss_value += classifier_loss.item() / config["test_interval"]
        if "ALDA" in config['method']:
            loss_adv_value += adv_loss.item() / config["test_interval"]
            loss_correct_value += correct_loss.item() / config["test_interval"]
        total_loss = classifier_loss + transfer_loss
        total_loss.backward()
        optimizer.step()
    checkpoint = {
        "base_network": temp_model.state_dict(),
        "ad_net": ad_net.state_dict()
    }
    torch.save(checkpoint, osp.join(config["output_path"], "final_model.pth"))
    return best_acc
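`loss.SelfTraining_loss` used in the 'SelfTraining' branch above is not reproduced here. A minimal sketch of a thresholded self-training objective, assumed behaviour only (shown for target outputs, whereas the call above passes the concatenated source/target outputs):

import torch
import torch.nn.functional as F

def self_training_loss_sketch(outputs_target, softmax_target, threshold):
    confidence, pseudo_labels = torch.max(softmax_target, dim=1)
    mask = (confidence > threshold).float()                       # keep only confident predictions
    ce = F.cross_entropy(outputs_target, pseudo_labels, reduction='none')
    return (ce * mask).sum() / (mask.sum() + 1e-8)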
Exemplo n.º 22
0
def train(config):
    # set pre-process
    prep_dict = {}
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_target(**config["prep"]['params'])
    prep_dict["target"] = prep.image_target(**config["prep"]['params'])
    prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    # prepare data
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(),
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs,
                                        shuffle=True, num_workers=4, drop_last=True)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(),
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs,
                                        shuffle=True, num_workers=4, drop_last=True)

    dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(),
                              transform=prep_dict["test"])
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs,
                                      shuffle=False, num_workers=4)

    # set base network
    class_num = config["network"]["params"]["class_num"]
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.cuda()

    # add additional network for some methods
    ad_net = network.AdversarialNetwork(class_num, 1024)
    ad_net = ad_net.cuda()

    # set optimizer
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list,
                                         **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    # multi gpu
    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net, device_ids=list(range(len(gpus))))
        base_network = nn.DataParallel(base_network, device_ids=list(range(len(gpus))))

    # train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    for i in range(config["num_iterations"]):
        # test
        if i % config["test_interval"] == config["test_interval"] - 1:
            base_network.train(False)
            temp_acc = image_classification_test(dset_loaders, base_network, gvbg=config["GVBG"])
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str+"\n")
            config["out_file"].flush()
            print(log_str)
        # save model
        if i % config["snapshot_interval"] == 0:
            torch.save(base_network.state_dict(), osp.join(config["output_path"],
                                                           "iter_{:05d}_model.pth.tar".format(i)))

        # train one iter
        base_network.train(True)
        ad_net.train(True)
        loss_params = config["loss"]
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()

        # dataloader
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])

        # network
        inputs_source, labels_source = iter_source.next()
        inputs_target, _ = iter_target.next()
        inputs_source, inputs_target, labels_source = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()
        features_source, outputs_source, focal_source = base_network(inputs_source, gvbg=config["GVBG"])
        features_target, outputs_target, focal_target = base_network(inputs_target, gvbg=config["GVBG"])
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        focals = torch.cat((focal_source, focal_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)

        # loss calculation
        transfer_loss, mean_entropy, gvbg, gvbd = loss.GVB([softmax_out, focals], ad_net, network.calc_coeff(i), GVBD=config['GVBD'])
        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss + config["GVBG"] * gvbg + abs(config['GVBD']) * gvbd

        if i % config["print_num"] == 0:
            log_str = "iter: {:05d}, transferloss: {:.5f}, classifier_loss: {:.5f}, mean entropy:{:.5f}, gvbg:{:.5f}, gvbd:{:.5f}".format(i, transfer_loss, classifier_loss, mean_entropy, gvbg, gvbd)
            config["out_file"].write(log_str+"\n")
            config["out_file"].flush()
            # print(log_str)

        total_loss.backward()
        optimizer.step()
    torch.save(best_model, osp.join(config["output_path"], "best_model.pth.tar"))
    return best_acc
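Every `AdversarialNetwork` in these examples trains the feature extractor and the domain discriminator in opposite directions through a gradient reversal layer: identity in the forward pass, sign-flipped (and optionally scaled) gradient in the backward pass. A minimal, self-contained sketch; the repository's own layer may additionally schedule the coefficient with `calc_coeff`:

import torch

class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, coeff=1.0):
        ctx.coeff = coeff
        return x.view_as(x)                      # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.coeff * grad_output, None    # flip and scale the gradient w.r.t. x

def grad_reverse(x, coeff=1.0):
    return GradReverse.apply(x, coeff)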
Exemplo n.º 23
0
def train(config):
    # set pre-process
    prep_config = config["prep"]
    prep_dict = {}
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    # prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = datasets.ImageFolder(data_config['source']['list_path'], transform=prep_dict["source"])
    dset_loaders['source'] = getdataloader(dsets['source'], batchsize=train_bs, num_workers=4, drop_last=True, weightsampler=True)
    dsets["target"] = datasets.ImageFolder(data_config['target']['list_path'], transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs,
                                        shuffle=True, num_workers=4, drop_last=True)

    if prep_config["test_10crop"]:
        dsets["test"] = [datasets.ImageFolder(data_config['test']['list_path'],
                                              transform=prep_dict["test"][i]) for i in range(10)]
        dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs,
                                           shuffle=False, num_workers=4) for dset in dsets['test']]
    else:
        dsets["test"] = datasets.ImageFolder(data_config['test']['list_path'],
                                             transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs,
                                          shuffle=False, num_workers=4)

    class_num = config["network"]["params"]["class_num"]

    # set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.cuda()

    # set test_ad_net
    test_ad_net = network.AdversarialNetwork(base_network.output_num(), 1024, test_ad_net=True)
    test_ad_net = test_ad_net.cuda()

    # add additional network for some methods
    if config['method'] == 'DANN':
        random_layer = None
        ad_net = network.AdversarialNetwork(base_network.output_num(), 1024)
    elif config['method'] == 'MADA':
        random_layer = None
        ad_net = network.AdversarialNetworkClassGroup(base_network.output_num(), 1024, class_num)
    elif config['method'] == 'proposed':
        if config['loss']['random']:
            random_layer = network.RandomLayer([base_network.output_num(), class_num], config['loss']['random_dim'])
            ad_net = network.AdversarialNetwork(config['loss']['random_dim'], 1024)
            ad_net_group = network.AdversarialNetworkGroup(config['loss']['random_dim'], 256, class_num, config['center_threshold'])
        else:
            random_layer = None
            ad_net = network.AdversarialNetwork(base_network.output_num(), 1024)
            ad_net_group = network.AdversarialNetworkGroup(base_network.output_num(), 1024, class_num, config['center_threshold'])
    elif config['method'] == 'base':
        pass
    else:
        if config["loss"]["random"]:
            random_layer = network.RandomLayer([base_network.output_num(), class_num], config["loss"]["random_dim"])
            ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
        else:
            random_layer = None
            ad_net = network.AdversarialNetwork(base_network.output_num() * class_num, 1024)
    if config["loss"]["random"] and config['method'] != 'base' and config['method'] != 'DANN' and config['method'] != 'MADA':
        random_layer.cuda()
    if config['method'] != 'base':
        ad_net = ad_net.cuda()
    if config['method'] == 'proposed':
        ad_net_group = ad_net_group.cuda()

    # set parameters
    if config['method'] == 'proposed':
        parameter_list = base_network.get_parameters() + test_ad_net.get_parameters() + ad_net.get_parameters() + ad_net_group.get_parameters()
    elif config['method'] == 'base':
        parameter_list = base_network.get_parameters() + test_ad_net.get_parameters()
    elif config['method'] == 'MADA':
        parameter_list = base_network.get_parameters() + test_ad_net.get_parameters() + ad_net.get_parameters()
    else:
        parameter_list = base_network.get_parameters() + test_ad_net.get_parameters() + ad_net.get_parameters()

    # set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    # parallel
    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        base_network = nn.DataParallel(base_network)
        test_ad_net = nn.DataParallel(test_ad_net)
        if config['method'] == 'DANN':
            ad_net = nn.DataParallel(ad_net)
        elif config['method'] == 'proposed':
            if config['loss']['random']:
                random_layer = nn.DataParallel(random_layer)
                ad_net = nn.DataParallel(ad_net)
                #Wrapping ad_net_group in nn.DataParallel raises an error, probably because its output is not a plain tensor, which DataParallel cannot handle yet.
                #ad_net_group = nn.DataParallel(ad_net_group)
            else:
                ad_net = nn.DataParallel(ad_net)
                #ad_net_group = nn.DataParallel(ad_net_group)
        elif config['method'] == 'base':
            pass
        else:
            # CDAN+E
            if config["loss"]["random"]:
                random_layer = nn.DataParallel(random_layer)
                ad_net = nn.DataParallel(ad_net)
            # CDAN
            else:
                ad_net = nn.DataParallel(ad_net)

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == config["test_interval"] - 1:
            base_network.train(False)  # equivalent to base_network.eval()
            temp_acc = image_classification_test(dset_loaders, base_network, test_10crop=prep_config["test_10crop"])
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()
            print(log_str)
        # if i % config["snapshot_interval"] == 0:
        #     torch.save(nn.Sequential(base_network), osp.join(config["output_path"],
        #                                                      "iter_{:05d}_model.pth.tar".format(i)))

        loss_params = config["loss"]
        # train one iter
        base_network.train(True)
        if config['method'] != 'base':
            ad_net.train(True)
        if config['method'] == 'proposed':
            ad_net_group.train(True)
        # lr_scheduler
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = iter_source.next()
        inputs_target, labels_target = iter_target.next()
        inputs_source, inputs_target, labels_source = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()
        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)
        if config['tsne']:
            # feature visualization by using T-SNE
            if i == int(0.98*config['num_iterations']):
                features_source_total = features_source.cpu().detach().numpy()
                features_target_total = features_target.cpu().detach().numpy()
            elif i > int(0.98*config['num_iterations']) and i < int(0.98*config['num_iterations'])+10:
                features_source_total = np.concatenate((features_source_total, features_source.cpu().detach().numpy()))
                features_target_total = np.concatenate((features_target_total, features_target.cpu().detach().numpy()))
            elif i == int(0.98*config['num_iterations'])+10:
                for index in range(config['tsne_num']):
                    features_embeded = TSNE(perplexity=10,n_iter=5000).fit_transform(np.concatenate((features_source_total, features_target_total)))
                    fig = plt.figure()
                    plt.scatter(features_embeded[:len(features_embeded)//2, 0], features_embeded[:len(features_embeded)//2, 1], c='r', s=1)
                    plt.scatter(features_embeded[len(features_embeded)//2:, 0], features_embeded[len(features_embeded)//2:, 1], c='b', s=1)
                    plt.savefig(osp.join(config["output_path"], config['method']+'-'+str(index)+'.png'))
                    plt.close()
            else:
                pass

        assert features_source.size(0) == features_target.size(0), 'The batch sizes must be the same'
        assert outputs_source.size(0) == outputs_target.size(0), 'The batch sizes must be the same'
        # source first, target second
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)

        # output the A_distance
        if i % config["test_interval"] == config["test_interval"] - 1:
            A_distance = cal_A_distance(test_ad_net, features)
            config['A_distance_file'].write(str(A_distance)+'\n')
            config['A_distance_file'].flush()

        softmax_out = nn.Softmax(dim=1)(outputs)
        if config['method'] == 'CDAN+E':
            entropy = loss.Entropy(softmax_out)
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy, network.calc_coeff(i), random_layer)
        elif config['method'] == 'CDAN':
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, None, None, random_layer)
        elif config['method'] == 'DANN':
            transfer_loss = loss.DANN(features, ad_net)
        elif config['method'] == 'MADA':
            transfer_loss = loss.MADA(features, softmax_out, ad_net)
        elif config['method'] == 'proposed':
            entropy = loss.Entropy(softmax_out)
            transfer_loss = loss.proposed([features, outputs], labels_source, ad_net, ad_net_group, entropy,
                                          network.calc_coeff(i), i, random_layer, config['loss']['trade_off23'])
        elif config['method'] == 'base':
            pass
        else:
            raise ValueError('Method cannot be recognized.')
        test_domain_loss = loss.DANN(features.clone().detach(), test_ad_net)
        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        if config['method'] == 'base':
            total_loss = classifier_loss + test_domain_loss
        else:
            total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss + test_domain_loss
        total_loss.backward()
        optimizer.step()
    # torch.save(best_model, osp.join(config["output_path"], "best_model.pth.tar"))
    return best_acc
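`cal_A_distance` above logs a proxy A-distance, commonly estimated as 2(1 - 2ε) from the error ε of a separately trained domain classifier. A hedged sketch, assuming `test_ad_net` returns one sigmoid probability per sample and that the batch is ordered source-first/target-second as asserted above:

import torch

def proxy_a_distance_sketch(domain_classifier, features):
    with torch.no_grad():
        probs = domain_classifier(features).view(-1)   # assumed sigmoid outputs in (0, 1)
        half = features.size(0) // 2
        labels = torch.cat([torch.ones(half), torch.zeros(half)]).to(probs.device)
        err = ((probs > 0.5).float() != labels).float().mean().item()
    return 2.0 * (1.0 - 2.0 * err)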
Exemplo n.º 24
0
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
                                        shuffle=True, num_workers=4, drop_last=True)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
                                        shuffle=True, num_workers=4, drop_last=True)

    if prep_config["test_10crop"]:
        dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                   transform=prep_dict["test"][i]) for i in range(10)]
        dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                                           shuffle=False, num_workers=4) for dset in dsets['test']]
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                  transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                          shuffle=False, num_workers=4)

    class_num = config["network"]["params"]["class_num"]
    crit = LabelSmoothingLoss(smoothing=0.1, classes=class_num)  # label smoothing for the source classification loss

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.cuda()  # move the base network to the GPU

    ## add additional network for some methods
    if config["loss"]["random"]:
        random_layer = network.RandomLayer([base_network.output_num(), class_num], config["loss"]["random_dim"])
        ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(base_network.output_num() * class_num, 1024)  # adversarial (domain discriminator) network
    if config["loss"]["random"]:
        random_layer.cuda()
    ad_net = ad_net.cuda()
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                                         **(optimizer_config["optim_params"]))

    # center loss and its dedicated optimizer
    criterion_centor = CenterLoss(num_classes=class_num, feat_dim=256, use_gpu=True)
    optimizer_centerloss = torch.optim.SGD(criterion_centor.parameters(), lr=config['lr'])

    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
        base_network = nn.DataParallel(base_network, device_ids=[int(i) for i in gpus])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    start_time = time.time()
    for i in range(config["num_iterations"]):

        if i % config["test_interval"] == config["test_interval"] - 1:
            # run evaluation at this interval
            base_network.train(False)
            temp_acc = image_classification_test(dset_loaders, \
                                                 base_network, test_10crop=prep_config["test_10crop"])
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()
            print(log_str)
            end_time = time.time()
            print('iter {} cost time {:.4f} sec.'.format(i, end_time - start_time))  # print elapsed time since the last evaluation
            start_time = time.time()

        if i % config["snapshot_interval"] == 0:
            torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \
                                                             "iter_{:05d}_model.pth.tar".format(i)))

        loss_params = config["loss"]

        ## train one iter
        base_network.train(True)  # training mode
        ad_net.train(True)

        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        # optimizer_centerloss=lr_scheduler(optimizer_centerloss, i, **schedule_param)

        optimizer.zero_grad()
        optimizer_centerloss.zero_grad()

        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = iter_source.next()
        inputs_target, labels_target = iter_target.next()
        inputs_source, inputs_target, labels_source = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()
        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)
        if config['method'] == 'CDAN+E':
            entropy = loss.Entropy(softmax_out)
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy, network.calc_coeff(i), random_layer)
        elif config['method'] == 'CDAN':
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, None, None, random_layer)
        elif config['method'] == 'DANN':
            transfer_loss = loss.DANN(features, ad_net)
        else:
            raise ValueError('Method cannot be recognized.')
        # classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)  # plain source classification loss
        classifier_loss = crit(outputs_source, labels_source)  # source classification loss with label smoothing

        # compute the center loss on the source features
        loss_centor = criterion_centor(features_source, labels_source)

        total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss + config['centor_w']*loss_centor
        if i % config["test_interval"] == config["test_interval"] - 1:
            print('total loss: {:.4f}, transfer loss: {:.4f}, classifier loss: {:.4f}, center loss: {:.4f}'.format(
                total_loss.item(), transfer_loss.item(), classifier_loss.item(), config['centor_w'] * loss_centor.item()
            ))
        total_loss.backward()
        optimizer.step()
        # by doing so, weight_cent would not impact on the learning of centers
        for param in criterion_centor.parameters():
            param.grad.data *= (1. / config['centor_w'])
        optimizer_centerloss.step()

    torch.save(best_model, osp.join(config["output_path"], "best_model.pth.tar"))
    return best_acc
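# Hedged sketch: a hypothetical config dict covering the keys that the training
# loop above reads (num_iterations, test_interval, snapshot_interval, method,
# centor_w, loss trade_off, output_path, out_file). The full example builds its
# config elsewhere with more keys (network, optimizer, prep, data), so this is
# illustrative only, not the author's actual settings.
import sys

example_config = {
    "num_iterations": 10000,
    "test_interval": 500,
    "snapshot_interval": 5000,
    "method": "CDAN+E",
    "centor_w": 0.001,               # weight of the center loss term
    "loss": {"trade_off": 1.0},      # weight of the transfer loss term
    "output_path": "snapshot/demo",  # placeholder path
    "out_file": sys.stdout,          # anything with write()/flush() works here
}
# best_acc = train(example_config)   # assuming the remaining keys are also filled in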
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='CDAN USPS MNIST')
    parser.add_argument('--method', type=str, default='CDAN-E', choices=['CDAN', 'CDAN-E', 'DANN'])
    parser.add_argument('--task', default='MNIST2USPS', help='task to perform')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test_batch_size', type=int, default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--gpu_id', type=str, default='0',
                        help='cuda device id')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log_interval', type=int, default=10,
                        help='how many batches to wait before logging training status')
    parser.add_argument('--random', type=bool, default=False,
                        help='whether to use random')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id


    source_list = '/data/mnist/mnist_train.txt'
    target_list = '/data/usps/usps_train.txt'
    test_list = '/data/usps/usps_test.txt'
    start_epoch = 1
    decay_epoch = 5


    train_loader = torch.utils.data.DataLoader(
        ImageList(open(source_list).readlines(), transform=transforms.Compose([
                           transforms.Resize((28,28)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,))
                       ]), mode='L'),
        batch_size=args.batch_size, shuffle=True, num_workers=1, drop_last=True)
    train_loader1 = torch.utils.data.DataLoader(
        ImageList(open(target_list).readlines(), transform=transforms.Compose([
                           transforms.Resize((28,28)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,))
                       ]), mode='L'),
        batch_size=args.batch_size, shuffle=True, num_workers=1, drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        ImageList(open(test_list).readlines(), transform=transforms.Compose([
                           transforms.Resize((28,28)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,))
                       ]), mode='L'),
        batch_size=args.test_batch_size, shuffle=True, num_workers=1)

    model = network.LeNet()
    # model = model.cuda()
    class_num = 10

    if args.random:
        random_layer = network.RandomLayer([model.output_num(), class_num], 500)
        ad_net = network.AdversarialNetwork(500, 500)
        # random_layer.cuda()
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(model.output_num() * class_num, 500)
    # ad_net = ad_net.cuda()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=0.0005, momentum=0.9)
    optimizer_ad = optim.SGD(ad_net.parameters(), lr=args.lr, weight_decay=0.0005, momentum=0.9)

    for epoch in range(1, args.epochs + 1):
        if epoch % decay_epoch == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.5
        train(args, model, ad_net, random_layer, train_loader, train_loader1, optimizer, optimizer_ad, epoch, start_epoch, args.method)
        test(args, model, test_loader)
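# Hedged usage sketch: the snippet defines main() but does not show an entry
# point, so a conventional guard is assumed here.
if __name__ == '__main__':
    main()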
Exemplo n.º 26
0
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    prep_dict["target"] = prep.image_train( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    else:
        prep_dict["test"] = prep.image_test( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])

    ## set loss
    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    loss_params = config["loss"]

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = util_data.DataLoader(dsets["source"], \
            batch_size=data_config["source"]["batch_size"], \
            shuffle=True, num_workers=4)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = util_data.DataLoader(dsets["target"], \
            batch_size=data_config["target"]["batch_size"], \
            shuffle=True, num_workers=4)

    if prep_config["test_10crop"]:
        for i in range(10):
            dsets["test"+str(i)] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                transform=prep_dict["test"]["val"+str(i)])
            dset_loaders["test"+str(i)] = util_data.DataLoader(dsets["test"+str(i)], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=4)

            dsets["target"+str(i)] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["test"]["val"+str(i)])
            dset_loaders["target"+str(i)] = util_data.DataLoader(dsets["target"+str(i)], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=4)
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                transform=prep_dict["test"])
        dset_loaders["test"] = util_data.DataLoader(dsets["test"], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=4)

        dsets["target_test"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["test"])
        dset_loaders["target_test"] = MyDataLoader(dsets["target_test"], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=4)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameters
    if net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr":1}, \
                            {"params":base_network.bottleneck.parameters(), "lr":10}, \
                            {"params":base_network.fc.parameters(), "lr":10}]
        else:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr":1}, \
                            {"params":base_network.fc.parameters(), "lr":10}]
    else:
        parameter_list = [{"params": base_network.parameters(), "lr": 1}]

    ## add additional network for some methods
    class_weight = torch.from_numpy(np.array([1.0] * class_num))
    if use_gpu:
        class_weight = class_weight.cuda()
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(
        high_value=config["high"])
    if use_gpu:
        ad_net = ad_net.cuda()
    parameter_list.append({"params": ad_net.parameters(), "lr": 10})

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    len_train_source = len(dset_loaders["source"]) - 1
    len_train_target = len(dset_loaders["target"]) - 1
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == 0:
            base_network.train(False)
            temp_acc = image_classification_test(dset_loaders, \
                base_network, test_10crop=prep_config["test_10crop"], \
                gpu=use_gpu)
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            print(log_str)
        if i % config["snapshot_interval"] == 0:
            torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \
                "iter_{:05d}_model.pth.tar".format(i)))

        ## train one iter
        base_network.train(True)
        optimizer = lr_scheduler(param_lr, optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        if use_gpu:
            inputs_source, inputs_target, labels_source = \
                Variable(inputs_source).cuda(), Variable(inputs_target).cuda(), \
                Variable(labels_source).cuda()
        else:
            inputs_source, inputs_target, labels_source = Variable(inputs_source), \
                Variable(inputs_target), Variable(labels_source)

        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        features, outputs = base_network(inputs)

        softmax_out = nn.Softmax(dim=1)(outputs).detach()
        ad_net.train(True)
        transfer_loss = transfer_criterion(features, ad_net,
                                           gradient_reverse_layer, use_gpu)

        classifier_loss = class_criterion(
            outputs.narrow(0, 0, inputs.size(0) // 2), labels_source)

        total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss
        total_loss.backward()
        optimizer.step()
    torch.save(best_model, osp.join(config["output_path"],
                                    "best_model.pth.tar"))
    return best_acc
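# Hedged sketch of an "inv" learning-rate schedule compatible with the
# lr_scheduler(param_lr, optimizer, i, **schedule_param) call used above, where
# param_lr holds the initial lr of each parameter group. The lr_schedule module
# actually imported by the example may differ in detail.
def inv_lr_scheduler_sketch(param_lr, optimizer, iter_num, gamma=0.001, power=0.75, init_lr=0.001):
    # decay a shared base lr by (1 + gamma * iter)^(-power) and rescale each
    # parameter group by its own initial lr
    lr = init_lr * (1 + gamma * iter_num) ** (-power)
    for i, param_group in enumerate(optimizer.param_groups):
        param_group['lr'] = lr * param_lr[i]
    return optimizer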
Exemplo n.º 27
0
def train(config):
    ## set up summary writer
    writer = SummaryWriter(config['output_path'])

    # set up early stop
    early_stop_engine = EarlyStopping(config["early_stop_patience"])

    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    prep_dict["target"] = prep.image_train( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    else:
        prep_dict["test"] = prep.image_test( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
               
    ## set loss
    class_num = config["network"]["params"]["class_num"]
    loss_params = config["loss"]

    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    center_criterion = loss_params["loss_type"](num_classes=class_num, 
                                       feat_dim=config["network"]["params"]["bottleneck_dim"])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    dsets["source"] = ImageList(stratify_sampling(open(data_config["source"]["list_path"]).readlines(), prep_config["source_size"]), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = util_data.DataLoader(dsets["source"], \
            batch_size=data_config["source"]["batch_size"], \
            shuffle=True, num_workers=1)
    dsets["target"] = ImageList(stratify_sampling(open(data_config["target"]["list_path"]).readlines(), prep_config["target_size"]), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = util_data.DataLoader(dsets["target"], \
            batch_size=data_config["target"]["batch_size"], \
            shuffle=True, num_workers=1)

    if prep_config["test_10crop"]:
        for i in range(10):
            dsets["test"+str(i)] = ImageList(stratify_sampling(open(data_config["test"]["list_path"]).readlines(), ratio=prep_config['target_size']), \
                                transform=prep_dict["test"]["val"+str(i)])
            dset_loaders["test"+str(i)] = util_data.DataLoader(dsets["test"+str(i)], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=1)

            dsets["target"+str(i)] = ImageList(stratify_sampling(open(data_config["target"]["list_path"]).readlines(), ratio=prep_config['target_size']), \
                                transform=prep_dict["test"]["val"+str(i)])
            dset_loaders["target"+str(i)] = util_data.DataLoader(dsets["target"+str(i)], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=1)
    else:
        dsets["test"] = ImageList(stratify_sampling(open(data_config["test"]["list_path"]).readlines(), ratio=prep_config['target_size']), \
                                transform=prep_dict["test"])
        dset_loaders["test"] = util_data.DataLoader(dsets["test"], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=1)

        dsets["target_test"] = ImageList(stratify_sampling(open(data_config["target"]["list_path"]).readlines(), ratio=prep_config['target_size']), \
                                transform=prep_dict["test"])
        dset_loaders["target_test"] = MyDataLoader(dsets["target_test"], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=1)

    config['out_file'].write("dataset sizes: source={}, target={}\n".format(
        len(dsets["source"]), len(dsets["target"])))

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])


    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameters
    if net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.bottleneck.parameters(), "lr_mult":10, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
        else:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
    else:
        parameter_list = [{"params":base_network.parameters(), "lr_mult":1, 'decay_mult':2}]

    ## add additional network for some methods
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(high_value=config["high"]) #, 
                                                      #max_iter_value=config["num_iterations"])
    if use_gpu:
        ad_net = ad_net.cuda()
    parameter_list.append({"params":ad_net.parameters(), "lr_mult":10, 'decay_mult':2})
    parameter_list.append({"params":center_criterion.parameters(), "lr_mult": 10, 'decay_mult':1})
 
    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]


    ## train   
    len_train_source = len(dset_loaders["source"]) - 1
    len_train_target = len(dset_loaders["target"]) - 1
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == 0:
            base_network.train(False)
            if config['loss']['ly_type'] == "cosine":
                temp_acc = image_classification_test(dset_loaders, \
                    base_network, test_10crop=prep_config["test_10crop"], \
                    gpu=use_gpu)
            elif config['loss']['ly_type'] == "euclidean":
                temp_acc, _ = distance_classification_test(dset_loaders, \
                    base_network, center_criterion.centers.detach(), test_10crop=prep_config["test_10crop"], \
                    gpu=use_gpu)
            else:
                raise ValueError("no test method for cls loss: {}".format(config['loss']['ly_type']))
            
            snapshot_obj = {'step': i, 
                            "base_network": base_network.state_dict(), 
                            'precision': temp_acc, 
                            }
            if config["loss"]["loss_name"] != "laplacian" and config["loss"]["ly_type"] == "euclidean":
                snapshot_obj['center_criterion'] = center_criterion.state_dict()
            if temp_acc > best_acc:
                best_acc = temp_acc
                # save best model
                torch.save(snapshot_obj, 
                           osp.join(config["output_path"], "best_model.pth.tar"))
            log_str = "iter: {:05d}, {} precision: {:.5f}\n".format(i, config['loss']['ly_type'], temp_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            writer.add_scalar("precision", temp_acc, i)

            if early_stop_engine.is_stop_training(temp_acc):
                config["out_file"].write("no improvement after {}, stop training at step {}\n".format(
                    config["early_stop_patience"], i))
                # config["out_file"].write("finish training! \n")
                break

        if (i+1) % config["snapshot_interval"] == 0:
            torch.save(snapshot_obj, 
                       osp.join(config["output_path"], "iter_{:05d}_model.pth.tar".format(i)))
        

        ## train one iter
        base_network.train(True)
        optimizer = lr_scheduler(param_lr, optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        if use_gpu:
            inputs_source, inputs_target, labels_source = \
                Variable(inputs_source).cuda(), Variable(inputs_target).cuda(), \
                Variable(labels_source).cuda()
        else:
            inputs_source, inputs_target, labels_source = Variable(inputs_source), \
                Variable(inputs_target), Variable(labels_source)
           
        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        source_batch_size = inputs_source.size(0)

        if config['loss']['ly_type'] == 'cosine':
            features, logits = base_network(inputs)
            source_logits = logits.narrow(0, 0, source_batch_size)
        elif config['loss']['ly_type'] == 'euclidean':
            features, _ = base_network(inputs)
            logits = -1.0 * loss.distance_to_centroids(features, center_criterion.centers.detach())
            source_logits = logits.narrow(0, 0, source_batch_size)
        else:
            raise ValueError("no training method for cls loss: {}".format(config['loss']['ly_type']))

        ad_net.train(True)
        weight_ad = torch.ones(inputs.size(0))
        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                           weight_ad, use_gpu)
        ad_out, _ = ad_net(features.detach())
        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(ad_out)

        # source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source)
        # fisher loss on labeled source domain
        fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(features.narrow(0, 0, int(inputs.size(0)/2)), labels_source, inter_class=config["loss"]["inter_type"], 
                                                                               intra_loss_weight=loss_params["intra_loss_coef"], inter_loss_weight=loss_params["inter_loss_coef"])
        # entropy minimization loss
        em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

        # final loss
        total_loss = loss_params["trade_off"] * transfer_loss \
                     + fisher_loss \
                     + loss_params["em_loss_coef"] * em_loss \
                     + classifier_loss

        total_loss.backward()
        if center_grad is not None:
            # clear mmc_loss
            center_criterion.centers.grad.zero_()
            # Manually assign centers gradients other than using autograd
            center_criterion.centers.backward(center_grad)

        optimizer.step()

        if i % config["log_iter"] == 0:
            config['out_file'].write('iter {}: total loss={:0.4f}, transfer loss={:0.4f}, cls loss={:0.4f}, '
                'em loss={:0.4f}, '
                'mmc loss={:0.4f}, intra loss={:0.4f}, inter loss={:0.4f}, '
                'ad acc={:0.4f}, source_acc={:0.4f}, target_acc={:0.4f}\n'.format(
                i, total_loss.data.cpu().float().item(), transfer_loss.data.cpu().float().item(), classifier_loss.data.cpu().float().item(), 
                em_loss.data.cpu().float().item(), 
                fisher_loss.cpu().float().item(), fisher_intra_loss.cpu().float().item(), fisher_inter_loss.cpu().float().item(),
                ad_acc, source_acc_ad, target_acc_ad, 
                ))

            config['out_file'].flush()
            writer.add_scalar("total_loss", total_loss.data.cpu().float().item(), i)
            writer.add_scalar("cls_loss", classifier_loss.data.cpu().float().item(), i)
            writer.add_scalar("transfer_loss", transfer_loss.data.cpu().float().item(), i)
            writer.add_scalar("ad_acc", ad_acc, i)
            writer.add_scalar("d_loss/total", fisher_loss.data.cpu().float().item(), i)
            writer.add_scalar("d_loss/intra", fisher_intra_loss.data.cpu().float().item(), i)
            writer.add_scalar("d_loss/inter", fisher_inter_loss.data.cpu().float().item(), i)
        
    return best_acc
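# Hedged sketch of an EarlyStopping helper compatible with how it is used above:
# constructed with a patience value and queried with is_stop_training(metric),
# returning True once the metric has stopped improving. The example imports its
# own implementation, which may differ.
class EarlyStoppingSketch(object):
    def __init__(self, patience):
        self.patience = patience
        self.best = None
        self.num_bad_checks = 0

    def is_stop_training(self, metric):
        # True once `metric` has failed to improve for `patience` consecutive checks
        if self.best is None or metric > self.best:
            self.best = metric
            self.num_bad_checks = 0
        else:
            self.num_bad_checks += 1
        return self.num_bad_checks >= self.patience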
Exemplo n.º 28
0
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    prep_dict["target"] = prep.image_train( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    else:
        prep_dict["test"] = prep.image_test( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])

    ## set loss
    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    loss_params = config["loss"]

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]

    # separate the source set from the validation set
    cls_source_list, cls_validation_list = sep.split_set(
        data_config["source"]["list_path"],
        config["network"]["params"]["class_num"])
    source_list = sep.dimension_rd(cls_source_list)

    dsets["source"] = ImageList(source_list, \
                                transform=prep_dict["source"])
    dset_loaders["source"] = util_data.DataLoader(dsets["source"], \
            batch_size=data_config["source"]["batch_size"], \
            shuffle=True, num_workers=4)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = util_data.DataLoader(dsets["target"], \
            batch_size=data_config["target"]["batch_size"], \
            shuffle=True, num_workers=4)

    if prep_config["test_10crop"]:
        for i in range(10):
            dsets["test"+str(i)] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                transform=prep_dict["test"]["val"+str(i)])
            dset_loaders["test"+str(i)] = util_data.DataLoader(dsets["test"+str(i)], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=4)

            dsets["target"+str(i)] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["test"]["val"+str(i)])
            dset_loaders["target"+str(i)] = util_data.DataLoader(dsets["target"+str(i)], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=4)
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                transform=prep_dict["test"])
        dset_loaders["test"] = util_data.DataLoader(dsets["test"], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=4)

        dsets["target_test"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["test"])
        dset_loaders["target_test"] = MyDataLoader(dsets["target_test"], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=4)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameters
    if net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr":1}, \
                            {"params":base_network.bottleneck.parameters(), "lr":10}, \
                            {"params":base_network.fc.parameters(), "lr":10}]
        else:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr":1}, \
                            {"params":base_network.fc.parameters(), "lr":10}]
    else:
        parameter_list = [{"params": base_network.parameters(), "lr": 1}]

    ## add additional network for some methods
    class_weight = torch.from_numpy(np.array([1.0] * class_num))
    if use_gpu:
        class_weight = class_weight.cuda()
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(
        high_value=config["high"])
    if use_gpu:
        ad_net = ad_net.cuda()
    parameter_list.append({"params": ad_net.parameters(), "lr": 10})

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    len_train_source = len(dset_loaders["source"]) - 1
    len_train_target = len(dset_loaders["target"]) - 1
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    best_model = 0
    for i in range(config["num_iterations"]):
        if (i + 1) % config["test_interval"] == 0:
            base_network.train(False)
            temp_acc = image_classification_test(dset_loaders, \
                base_network, test_10crop=prep_config["test_10crop"], \
                gpu=use_gpu)
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            print(log_str)
        if (i + 1) % config["snapshot_interval"] == 0:
            torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \
                "iter_{:05d}_model.pth.tar".format(i)))

        # if i % loss_params["update_iter"] == loss_params["update_iter"] - 1:
        #     base_network.train(False)
        #     target_fc8_out = image_classification_predict(dset_loaders, base_network, softmax_param=config["softmax_param"])
        #     class_weight = torch.mean(target_fc8_out, 0)
        #     class_weight = (class_weight / torch.mean(class_weight)).cuda().view(-1)
        #     class_criterion = nn.CrossEntropyLoss(weight = class_weight)

        ## train one iter
        base_network.train(True)
        optimizer = lr_scheduler(param_lr, optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        if use_gpu:
            inputs_source, inputs_target, labels_source = \
                Variable(inputs_source).cuda(), Variable(inputs_target).cuda(), \
                Variable(labels_source).cuda()
        else:
            inputs_source, inputs_target, labels_source = Variable(inputs_source), \
                Variable(inputs_target), Variable(labels_source)

        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        features, outputs = base_network(inputs)

        # #
        # if i % 100 == 0:
        #     check = dev.get_label_list(open(data_config["source"]["list_path"]).readlines(),
        #                                base_network,
        #                                prep_config["resize_size"],
        #                                prep_config["crop_size"],
        #                                data_config["target"]["batch_size"],
        #                                use_gpu)
        #     f = open("Class_result.txt", "a+")
        #     f.close()
        #     for cls in range(class_num):
        #         count = 0
        #         for j in check:
        #             if int(j.split(" ")[1].replace("\n", "")) == cls:
        #                 count = count + 1
        #         f = open("Source_result.txt", "a+")
        #         f.write("Source_Class: " + str(cls) + "\n" + "Number of images: " + str(count) + "\n")
        #         f.close()
        #
        #     check = dev.get_label_list(open(data_config["target"]["list_path"]).readlines(),
        #                                base_network,
        #                                prep_config["resize_size"],
        #                                prep_config["crop_size"],
        #                                data_config["target"]["batch_size"],
        #                                use_gpu)
        #     f = open("Class_result.txt", "a+")
        #     f.write("Iteration: " + str(i) + "\n")
        #     f.close()
        #     for cls in range(class_num):
        #         count = 0
        #         for j in check:
        #             if int(j.split(" ")[1].replace("\n", "")) == cls:
        #                 count = count + 1
        #         f = open("Class_result.txt", "a+")
        #         f.write("Target_Class: " + str(cls) + "\n" + "Number of images: " + str(count) + "\n")
        #         f.close()

        #
        # #
        # print("Training test:")
        # print(features)
        # print(features.shape)
        # print(outputs)
        # print(outputs.shape)

        softmax_out = nn.Softmax(dim=1)(outputs).detach()
        ad_net.train(True)
        weight_ad = torch.ones(inputs.size(0))
        # label_numpy = labels_source.data.cpu().numpy()
        # for j in range(int(inputs.size(0) / 2)):
        #     weight_ad[j] = class_weight[int(label_numpy[j])]
        # weight_ad = weight_ad / torch.max(weight_ad[0:int(inputs.size(0)/2)])
        # for j in range(int(inputs.size(0) / 2), inputs.size(0)):
        #     weight_ad[j] = 1.0
        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                           weight_ad, use_gpu)

        classifier_loss = class_criterion(
            outputs.narrow(0, 0, int(inputs.size(0) / 2)), labels_source)

        total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss
        total_loss.backward()
        optimizer.step()
    cv_loss = dev.cross_validation_loss(
        base_network, base_network, cls_source_list,
        data_config["target"]["list_path"], cls_validation_list, class_num,
        prep_config["resize_size"], prep_config["crop_size"],
        data_config["target"]["batch_size"], use_gpu)
    print(cv_loss)
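# Hedged sketch of the ImageList dataset these examples build their loaders
# from: each line of the list file is assumed to be "<image path> <label>",
# loaded with PIL and passed through the given transform. The real ImageList
# may differ (e.g. multi-label files or alternative loaders).
from PIL import Image
from torch.utils.data import Dataset

class ImageListSketch(Dataset):
    def __init__(self, image_list, transform=None, mode='RGB'):
        # image_list: list of "path label" strings, e.g. open(list_path).readlines()
        self.samples = [(line.split()[0], int(line.split()[1])) for line in image_list]
        self.transform = transform
        self.mode = mode

    def __getitem__(self, index):
        path, label = self.samples[index]
        img = Image.open(path).convert(self.mode)
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.samples)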
Exemplo n.º 29
0
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
            shuffle=True, num_workers=4, drop_last=True)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
            shuffle=True, num_workers=4, drop_last=True)

    #     if prep_config["test_10crop"]:
    #         for i in range(10):
    #             dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
    #                                 transform=prep_dict["test"][i]) for i in range(10)]
    #             dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
    #                                 shuffle=False, num_workers=4) for dset in dsets['test']]
    #     else:
    #         dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
    #                                 transform=prep_dict["test"])
    #         dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
    #                                 shuffle=False, num_workers=4)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.cuda()

    ## add additional network for some methods
    if config["loss"]["random"]:
        random_layer = network.RandomLayer(
            [base_network.output_num(), class_num],
            config["loss"]["random_dim"])
        ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(
            base_network.output_num() * class_num, 1024)
    if config["loss"]["random"]:
        random_layer.cuda()
    ad_net = ad_net.cuda()
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
        base_network = nn.DataParallel(base_network,
                                       device_ids=[int(i) for i in gpus])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    for i in range(config["num_iterations"]):
        #         if i % config["test_interval"] == config["test_interval"] - 1:
        #             base_network.train(False)
        #             temp_acc = image_classification_test(dset_loaders, \
        #                 base_network, test_10crop=prep_config["test_10crop"])
        #             temp_model = nn.Sequential(base_network)
        #             if temp_acc > best_acc:
        #                 best_acc = temp_acc
        #                 best_model = temp_model
        #             log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
        #             config["out_file"].write(log_str+"\n")
        #             config["out_file"].flush()
        #             print(log_str)
        if i % config["snapshot_interval"] == 0:
            torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \
                "iter_{:05d}_model.pth.tar".format(i)))

        loss_params = config["loss"]
        ## train one iter
        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        inputs_source, inputs_target, labels_source = inputs_source.cuda(
        ), inputs_target.cuda(), labels_source.cuda()
        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)
        if config['method'] == 'CDAN+E':
            entropy = loss.Entropy(softmax_out)
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy,
                                      network.calc_coeff(i), random_layer)
        elif config['method'] == 'CDAN':
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, None,
                                      None, random_layer)
        elif config['method'] == 'DANN':
            transfer_loss = loss.DANN(features, ad_net)
        else:
            raise ValueError('Method cannot be recognized.')
        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        if i % 10 == 0:
            print('iter: ', i, 'classifier_loss: ', classifier_loss.data,
                  'transfer_loss: ', transfer_loss.data)
        total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss
        total_loss.backward()
        optimizer.step()
    torch.save(best_model, osp.join(config["output_path"],
                                    "best_model.pth.tar"))
    return best_acc
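# Hedged sketches of the two helpers referenced in the CDAN+E branch above:
# loss.Entropy (per-sample entropy of the softmax outputs) and
# network.calc_coeff (a ramp-up coefficient for the adversarial loss). These
# are reconstructions from how they are called, not the example's own code.
import numpy as np
import torch

def entropy_sketch(softmax_out, eps=1e-5):
    # H(p) = -sum_c p_c * log(p_c), computed independently for each sample
    return -torch.sum(softmax_out * torch.log(softmax_out + eps), dim=1)

def calc_coeff_sketch(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
    # smoothly ramps the adversarial weight from `low` to `high` over training
    return float(2.0 * (high - low) / (1.0 + np.exp(-alpha * iter_num / max_iter))
                 - (high - low) + low)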
Exemplo n.º 30
0
def main():
    global args, best_prec1
    args = parser.parse_args()
    check_rootfolders()

    categories, args.train_source_list, args.train_target_list, args.val_list, args.root_path_source, args.root_path_target, prefix = datasets_video.return_dataset(
        args.dataset, args.modality)
    num_class = len(categories)

    args.store_name = '_'.join([
        'TRN', args.dataset, args.modality, args.arch, args.consensus_type,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn)
    ad_net = network.AdversarialNetwork(1024, 2048)
    ad_net = ad_net.cuda()

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    policies_ad = ad_net.get_parameters()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    source_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path_source,
        args.train_source_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                drop_last=True,
                                                num_workers=args.workers,
                                                pin_memory=True)

    target_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path_target,
        args.train_target_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                drop_last=True,
                                                num_workers=args.workers,
                                                pin_memory=True)

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path_target,
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             drop_last=True,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))
    parameter_list = policies + policies_ad
    optimizer = torch.optim.SGD(parameter_list,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion, num_class, 0)
        return

    log_training = open(
        os.path.join(args.root_log, '%s.csv' % args.store_name), 'w')
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(source_loader, target_loader, model, ad_net, criterion,
              optimizer, epoch, log_training, num_class)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            #prec1 = validate(val_loader, model, criterion, (epoch + 1) * len(target_loader), log_training)
            prec1 = validate(val_loader, model, criterion, num_class,
                             epoch + 1, log_training)
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)