Example 1
0
def main():
    """Train a DAN model with a ResNet-50 backbone on ImageCLEF, domain p -> c.

    Source-domain batches come from ``./image-clef/p``; target-domain train and
    test batches come from ``./image-clef/c``. Optimisation (SGD, momentum 0.9,
    weight decay 5e-4, no LR scheduler) is delegated to ``dan_train``.
    """
    src_dir = os.path.expanduser("./image-clef/p")
    tgt_dir = os.path.expanduser("./image-clef/c")

    train_bs = 32
    eval_bs = 32
    n_epochs = 200
    base_lr = 0.0001
    dev = torch.device("cuda")

    src_train = DA_datasets.imageclef_train_loader(src_dir, train_bs, 0)
    tgt_train = DA_datasets.imageclef_train_loader(tgt_dir, train_bs, 0)
    tgt_test = DA_datasets.imageclef_test_loader(tgt_dir, eval_bs, 0)

    # Alternative backbone: DAN_model.DANNetVGG16(models.vgg16, True)
    net = DAN_model.DANNet_ResNet(ResNet.resnet50, True).to(dev)

    sgd = torch.optim.SGD(net.parameters(),
                          momentum=0.9,
                          lr=base_lr,
                          weight_decay=5e-4)

    dan_train(n_epochs, base_lr, net, src_train, dev, tgt_train, tgt_test,
              sgd, scheduler=None, is_debug=False)
Example 2
0
def main():
    """Distil a DAN ResNet-50 teacher into a ResNet-34 student on Office-31.

    The teacher is initialised from ``<save_dir>/dan_resnet50_amazon_2_webcam.pth``
    (``save_dir`` is a name defined outside this function). Distillation runs
    through ``od_kd_without_label`` on the webcam target loaders. It uses Nesterov
    SGD over the student and the distiller's ``Connectors`` parameters, a StepLR
    schedule every ``epoch_step`` epochs, and a Visdom logger on port 10999.
    """
    batch_size = 64
    test_batch_size = 64
    lr = 0.1
    momentum = 0.9
    epochs = 100
    epoch_step = 30
    weight_decay = 1e-4
    teacher_pretrained_path = "{}/dan_resnet50_amazon_2_webcam.pth".format(save_dir)
    student_pretrained = False
    device = torch.device("cuda")
    multi_gpu = torch.cuda.device_count() > 1

    webcam = os.path.expanduser("~/datasets/webcam/images")
    amazon = os.path.expanduser("~/datasets/amazon/images")

    # The source (amazon) loader is built but not passed to od_kd_without_label.
    train_loader_source = DA_datasets.office_loader(amazon, batch_size, 0)
    train_loader_target = DA_datasets.office_loader(webcam, batch_size, 0)
    testloader_target = DA_datasets.office_test_loader(webcam, test_batch_size, 0)

    logger = LoggerForSacred(VisdomLogger(port=10999))

    teacher_model = DAN_model.DANNet_ResNet(ResNet.resnet50, True)
    student_model = DAN_model.DANNet_ResNet(ResNet.resnet34, student_pretrained)

    if teacher_pretrained_path != "":
        teacher_model.load_state_dict(torch.load(teacher_pretrained_path))

    if multi_gpu:
        teacher_model = torch.nn.DataParallel(teacher_model)
        student_model = torch.nn.DataParallel(student_model)
    # Move to the device on every path. Previously this only happened inside the
    # multi-GPU branch, so single-GPU runs kept the models on the CPU while
    # `device` (CUDA) was handed to the training loop.
    teacher_model = teacher_model.to(device)
    student_model = student_model.to(device)

    distiller_model = od_distiller.Distiller_DAN(teacher_model, student_model).to(device)
    # Unwrapped handle: once wrapped in DataParallel, attributes such as
    # ``Connectors`` are only reachable through ``.module``.
    distiller_core = distiller_model
    if multi_gpu:
        distiller_model = torch.nn.DataParallel(distiller_model).to(device)

    optimizer = torch.optim.SGD(
        list(student_model.parameters()) + list(distiller_core.Connectors.parameters()),
        lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, epoch_step)

    od_kd_without_label(epochs, teacher_model, student_model, distiller_model, optimizer,
                        train_loader_target, testloader_target, device,
                        logger=logger, scheduler=scheduler)
Example 3
0
def main():
    """Distil an ImageCLEF ResNet-50 teacher into a ResNet-34 student (p -> i).

    The teacher weights are read from the ``'student_model'`` entry of the
    checkpoint at ``./da_resnet50_p_i.pth``. Label-free distillation on the ``i``
    domain is delegated to ``od_kd_without_label``.
    """
    batch_size = 32
    test_batch_size = 32
    lr = 0.0005
    momentum = 0.9
    epochs = 200
    weight_decay = 5e-4
    teacher_pretrained_path = "./da_resnet50_p_i.pth"
    student_pretrained = False
    device = torch.device("cuda")
    multi_gpu = torch.cuda.device_count() > 1

    p = os.path.expanduser("./image-clef/p")
    i = os.path.expanduser("./image-clef/i")

    # The source (p) loader is built but not passed to od_kd_without_label.
    train_loader_source = DA_datasets.imageclef_train_loader(p, batch_size, 0)
    train_loader_target = DA_datasets.imageclef_train_loader(i, batch_size, 0)
    testloader_target = DA_datasets.imageclef_test_loader(i, test_batch_size, 0)

    teacher_model = DAN_model.DANNet_ResNet(ResNet.resnet50, False).to(device)
    student_model = DAN_model.DANNet_ResNet(ResNet.resnet34,
                                            student_pretrained).to(device)
    if teacher_pretrained_path != "":
        teacher_model.load_state_dict(
            torch.load(teacher_pretrained_path)['student_model'])

    if multi_gpu:
        teacher_model = torch.nn.DataParallel(teacher_model).to(device)
        student_model = torch.nn.DataParallel(student_model).to(device)

    distiller_model = distiller.Distiller(teacher_model, student_model).to(device)
    # Keep the unwrapped distiller: after DataParallel wrapping, ``s_net`` and
    # ``Connectors`` live under ``.module``, and ``distiller_model.s_net`` would
    # raise AttributeError on multi-GPU machines.
    distiller_core = distiller_model
    if multi_gpu:
        distiller_model = torch.nn.DataParallel(distiller_model).to(device)

    optimizer = torch.optim.SGD(list(distiller_core.s_net.parameters()) +
                                list(distiller_core.Connectors.parameters()),
                                lr,
                                momentum=momentum,
                                weight_decay=weight_decay)

    od_kd_without_label(epochs, distiller_model, optimizer,
                        train_loader_target, testloader_target, device)
Example 4
0
def train_normal():
    """Supervised training of a DAN ResNet-50 on the ImageCLEF ``p`` domain.

    Each epoch makes one cross-entropy pass over the ``p`` train loader. It then
    scores the model on the ``p`` test loader with the project's ``eval`` helper
    and prints the epoch and accuracy. Whenever accuracy improves, the whole
    model object is saved to ``teacher_model_p.pth``.
    """
    batch_size = 32
    test_batch_size = 32
    lr = 0.001
    momentum = 0.9
    epochs = 200
    weight_decay = 5e-4
    device = torch.device("cuda")

    p = os.path.expanduser("./image-clef/p")

    train_dataset = DA_datasets.imageclef_train_loader(p, batch_size, 0)
    test_dataset = DA_datasets.imageclef_test_loader(p, test_batch_size, 0)
    model = DAN_model.DANNet_ResNet(ResNet.resnet50, True)
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr,
                                momentum=momentum,
                                weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    best_acc = 0.0
    for epoch in range(epochs):
        # NOTE(review): eval() presumably switches the model to eval mode; put it
        # back in training mode at the start of every epoch.
        model.train()
        # Was `enumerate(dataset, 0)`: `dataset` is undefined (NameError).
        for inputs, labels in train_dataset:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        acc = eval(model, device, test_dataset)
        if acc > best_acc:
            best_acc = acc
            torch.save(model, 'teacher_model_p.pth')
        print(f'epoch : {epoch}, acc : {acc}')
Example 5
0
def main():
    """Alternating MMD / Hinton distillation on ImageCLEF, p -> c.

    A DANN-GRL ResNet-101 teacher and a ResNet-34 student share one
    ``distiller.Distiller``. The DA optimizer covers only the teacher, and the KD
    optimizer covers the student plus the connector parameters. Training is
    delegated to ``mmd_hinton_alt``.
    """
    is_debug = False
    student_pretrained = True
    dev = torch.device("cuda")

    n_epochs = 400
    lr_da = 0.0001
    lr_kd = 0.0005
    momentum = 0.9
    weight_decay = 5e-4
    alpha = 0.5
    gamma = 0.5
    beta_start = 0.1
    beta_end = 0.9

    src_path = os.path.expanduser("./image-clef/p")
    tgt_path = os.path.expanduser("./image-clef/c")

    src_loader, tgt_loader, tgt_test = DA_datasets.get_source_target_loader(
        "ImageClef", src_path, tgt_path, 16, 0)
    n_classes = src_loader.dataset.num_classes
    use_dp = torch.cuda.device_count() > 1

    def build(backbone, pretrained):
        # Build a DANN-GRL network, wrap it for multi-GPU if needed, move to dev.
        net = DANN_GRL.DANN_GRL_Resnet(backbone, pretrained, n_classes)
        if use_dp:
            net = nn.DataParallel(net)
        return net.to(dev)

    teacher = build(ResNet.resnet101, True)
    student = build(ResNet.resnet34, student_pretrained)

    distil = distiller.Distiller(teacher, student).to(dev)
    criterion = nn.CrossEntropyLoss()

    # Exponential beta growth so that beta reaches beta_end after n_epochs.
    ratio = torch.FloatTensor([beta_end / beta_start])
    growth_rate = torch.log(ratio) / torch.FloatTensor([n_epochs])

    optimizer_da = torch.optim.SGD(distil.t_net.parameters(), lr_da,
                                   momentum=momentum, weight_decay=weight_decay)
    kd_params = list(distil.s_net.parameters()) + list(distil.Connectors.parameters())
    optimizer_kd = torch.optim.SGD(kd_params, lr_kd,
                                   momentum=momentum, weight_decay=weight_decay)

    mmd_hinton_alt(lr_da, lr_kd, dev, n_epochs, alpha, gamma, growth_rate,
                   beta_start, src_loader, tgt_loader, tgt_test,
                   optimizer_da, optimizer_kd, distil, criterion,
                   is_scheduler_kd=False, is_scheduler_da=True,
                   is_debug=is_debug, best_teacher_acc=0)
Example 6
0
def exp_kd_da_grl_alt(init_lr_da, init_lr_kd, momentum, weight_decay, device,
                      epochs, batch_size, init_beta, end_beta, T, alpha, gamma,
                      batch_norm, is_cst, resize_digits, is_scheduler_da,
                      is_scheduler_kd, scheduler_kd_fn, scheduler_kd_steps,
                      scheduler_kd_gamma, dataset_name, source_dataset_path,
                      target_dataset_paths, dan_model_func, teacher_net_func,
                      dan_model_func_student, student_net_func,
                      student_pretrained, is_debug, _run):
    """Run one multi-target GRL knowledge-distillation experiment.

    One teacher is built per entry of ``target_dataset_paths`` with
    ``dan_model_func(teacher_net_func, True, num_classes)``, and one student with
    ``dan_model_func_student``. Each teacher gets its own SGD optimizer for the DA
    phase, and the student gets one SGD optimizer for the KD phase. Training is
    delegated to
    ``kd_da_grl_alt_multi_target_cst_fac.grl_multi_target_hinton_alt``.
    Afterwards ``_run.config`` is dumped as JSON to
    ``all_confs/<run id>_<best acc>.json``.

    Args (selected):
        init_beta, end_beta: endpoints of the beta schedule. The growth rate is
            ``log(end_beta / init_beta) / epochs``, or zero when ``init_beta`` is 0.
        scheduler_kd_fn: optional factory, called as
            ``scheduler_kd_fn(optimizer_kd, scheduler_kd_steps, scheduler_kd_gamma)``;
            the resulting scheduler is forwarded to the training loop.
        resize_digits: resize forwarded to the loaders, and ``input_size`` for
            LeNet / MTDA_ITA_classifier students.
        _run: run object (Sacred-style) providing ``_id`` and ``config``.

    Returns:
        The best student accuracy returned by the training loop.
    """
    import os

    source_dataloader, targets_dataloader, targets_testloader = DA_datasets.get_source_m_target_loader(
        dataset_name,
        source_dataset_path,
        target_dataset_paths,
        batch_size,
        0,
        drop_last=True,
        resize=resize_digits)
    num_classes = source_dataloader.dataset.num_classes

    # One teacher per target domain.
    teacher_models = [
        dan_model_func(teacher_net_func, True, num_classes).to(device)
        for _ in target_dataset_paths
    ]

    if student_net_func in (Lenet.LeNet, Lenet.MTDA_ITA_classifier):
        student_model = dan_model_func_student(
            student_net_func,
            student_pretrained,
            num_classes,
            input_size=resize_digits).to(device)
    else:
        student_model = dan_model_func_student(
            student_net_func, student_pretrained, num_classes).to(device)

    if torch.cuda.device_count() > 1:
        teacher_models = [nn.DataParallel(tm).to(device) for tm in teacher_models]
        student_model = nn.DataParallel(student_model).to(device)

    logger = LoggerForSacred(None, ex)

    growth_rate = torch.zeros(1)
    if init_beta != 0.0:
        growth_rate = torch.log(torch.FloatTensor(
            [end_beta / init_beta])) / torch.FloatTensor([epochs])

    optimizer_das = [
        torch.optim.SGD(tm.parameters(),
                        init_lr_da,
                        momentum=momentum,
                        weight_decay=weight_decay)
        for tm in teacher_models
    ]

    optimizer_kd = torch.optim.SGD(student_model.parameters(),
                                   init_lr_kd,
                                   momentum=momentum,
                                   weight_decay=weight_decay)

    scheduler_kd = None
    if scheduler_kd_fn is not None:
        scheduler_kd = scheduler_kd_fn(optimizer_kd, scheduler_kd_steps,
                                       scheduler_kd_gamma)

    if dataset_name not in ("Digits", "Digits_no_split"):
        source_name_1 = get_sub_dataset_name(dataset_name, source_dataset_path)
    else:
        source_name_1 = source_dataset_path

    save_name = "best_{}_{}_and_2{}_kd_da_alt.p".format(
        _run._id, source_name_1, "the_rest")

    best_student_acc = kd_da_grl_alt_multi_target_cst_fac.grl_multi_target_hinton_alt(
        init_lr_da,
        init_lr_kd,
        device,
        epochs,
        T,
        alpha,
        gamma,
        growth_rate,
        init_beta,
        source_dataloader,
        targets_dataloader,
        targets_testloader,
        optimizer_das,
        optimizer_kd,
        teacher_models,
        student_model,
        logger=logger,
        is_scheduler_da=is_scheduler_da,
        is_scheduler_kd=is_scheduler_kd,
        # Forward the scheduler built above; it used to be discarded (None).
        scheduler_kd=scheduler_kd,
        scheduler_da=None,
        is_debug=is_debug,
        run=_run,
        save_name=save_name,
        batch_norm=batch_norm,
        is_cst=is_cst)

    # Create the output directory up front, so a missing folder cannot crash
    # the run after training has finished.
    conf_dir = "all_confs"
    os.makedirs(conf_dir, exist_ok=True)
    conf_path = "{}/{}_{}.json".format(conf_dir, _run._id, best_student_acc)
    with open(conf_path, 'w') as cf:
        json.dump(_run.config, cf, default=custom_json_dumper)

    return best_student_acc
Example 7
0
def main():
    """Print the accuracy of a freshly built student on each Office-31 target.

    Builds webcam -> {amazon, dslr} loaders and a ``DANN_GRL_Alexnet`` student
    (AlexNet backbone, ``begin_pretrained=True``). It prints that student's
    accuracy on every target test loader. A finished model is also loaded from
    ``finished_model_path`` and put in eval mode, but it is not evaluated in
    this function.
    """
    amazon = os.path.expanduser('~/datasets/amazon/images')
    webcam = os.path.expanduser('~/datasets/webcam/images')
    dslr = os.path.expanduser('~/datasets/dslr/images')

    device = torch.device("cuda")
    dataset_name = "Office31"
    source_dataset_path = webcam
    target_dataset_paths = [amazon, dslr]
    batch_size = 16
    resize_digits = 28
    student_net_func = AlexNet.alexnet
    dan_model_func_student = DANN_GRL.DANN_GRL_Alexnet
    begin_pretrained = True
    is_btda = False
    finished_model_path = "best_48_webcam_and_2the_rest_kd_da_alt.p"

    source_dataloader, _, targets_testloader = DA_datasets.get_source_m_target_loader(
        dataset_name,
        source_dataset_path,
        target_dataset_paths,
        batch_size,
        0,
        drop_last=True,
        resize=resize_digits)
    num_classes = source_dataloader.dataset.num_classes

    # LeNet students additionally need the input resolution.
    extra_kwargs = {"input_size": resize_digits} if student_net_func == LeNet.LeNet else {}
    begin_model = dan_model_func_student(
        student_net_func, begin_pretrained, num_classes, **extra_kwargs).to(device)

    logger = LoggerForSacred(None, None, True)

    if is_btda:
        finished_model = BTDA_Alexnet.Alex_Model_Office31()
        finished_model.load_state_dict(torch.load(finished_model_path))
        finished_model = finished_model.to(device)
    elif finished_model_path.endswith('p'):
        # Such files are loaded directly as a complete model object.
        finished_model = torch.load(finished_model_path).to(device)
    else:
        # Otherwise the file is a state dict for a student-shaped model.
        finished_model = dan_model_func_student(
            student_net_func, begin_pretrained, num_classes)
        finished_model.load_state_dict(torch.load(finished_model_path))
        finished_model = finished_model.to(device)
    finished_model.eval()

    s_name = get_sub_dataset_name(dataset_name, source_dataset_path)
    for idx, loader in enumerate(targets_testloader):
        acc = eval(begin_model, device, loader, False)
        t_name = get_sub_dataset_name(dataset_name, target_dataset_paths[idx])
        print("b_model_from{}_2_{}_acc:{}".format(s_name, t_name, acc))
Example 8
0
    return model_dan, optimizer, best_acc


if __name__ == "__main__":
    # Script entry point: trains DAN_model.DANNetVGG16 with dann_grl_train,
    # using Office-31 webcam as the source domain and amazon as the target.
    # NOTE: this block is not a function, so every name assigned here is a
    # module-level global that other code in the module may read.
    batch_size = 32
    test_batch_size = 32

    # Legacy absolute dataset paths (unused):
    #train_path = "/home/ens/AN88740/dataset/webcam/images"
    #test_path = "/home/ens/AN88740/dataset/amazon/images"

    webcam = os.path.expanduser("~/datasets/webcam/images")
    amazon = os.path.expanduser("~/datasets/amazon/images")
    dslr = os.path.expanduser("~/datasets/dslr/images")  # not used below

    epochs = 200
    lr = 0.01
    device = torch.device("cuda")

    # Source = webcam; target train/test = amazon.
    train_loader_source = DA_datasets.office_loader(webcam, batch_size, 0)
    train_loader_target = DA_datasets.office_loader(amazon, batch_size, 0)
    testloader_1_target = DA_datasets.office_test_loader(amazon, test_batch_size, 0)

    # Visdom server on port 9000, wrapped by the project's LoggerForSacred.
    logger = VisdomLogger(port=9000)
    logger = LoggerForSacred(logger)

    # Alternative backbone (ResNet-50), kept for reference:
    #model_dan = DAN_model.DANNet_ResNet(ResNet.resnet50, True).to(device)
    model_dan = DAN_model.DANNetVGG16(models.vgg16, True).to(device)

    optimizer = torch.optim.SGD(model_dan.parameters(), momentum=0.9, lr=lr, weight_decay=5e-4)
    dann_grl_train(epochs, lr, model_dan, train_loader_source, device, train_loader_target, testloader_1_target, optimizer, logger=logger,
                   logger_id="", scheduler=None, is_debug=False)
Example 9
0
def main():
    """Multi-target distillation: three ResNet-152 teachers into one ResNet-50.

    The source domain is ``../../datasets/Art``; the three targets are Clipart,
    Product and RealWorld. Each teacher is paired with the shared student in a
    ``distiller.Distiller``. Every pair gets a DA optimizer (student + teacher
    parameters) and a KD optimizer (student + connector parameters). Training
    is delegated to ``od_mmd_train``.
    """
    batch_size = 32
    batch_norm = True
    epochs = 400
    init_lr_da = 0.0001
    init_lr_kd = 0.001
    momentum = 0.9
    weight_decay = 5e-4
    alpha = 0.5
    gamma = 0.5
    init_beta = 0.1
    end_beta = 0.8
    is_scheduler_da = True
    is_scheduler_kd = False

    # NOTE(review): these are the OfficeHome domain folders, but dataset_name is
    # 'Office31' -- confirm which name DA_datasets expects for them.
    dataset_name = 'Office31'
    source_path = os.path.expanduser('../../datasets/Art')
    target_paths = [
        os.path.expanduser('../../datasets/Clipart'),
        os.path.expanduser('../../datasets/Product'),
        os.path.expanduser('../../datasets/RealWorld'),
    ]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    source_dataloader, targets_dataloader, targets_testloader = DA_datasets.get_source_m_target_loader(
        dataset_name, source_path, target_paths, batch_size, 0, drop_last=True)
    n_classes = source_dataloader.dataset.num_classes

    # Teachers are built first, then the student (same order as before).
    teachers = [
        DANN_GRL.DANN_GRL_Resnet(ResNet.resnet152, True, n_classes).to(device)
        for _ in range(3)
    ]
    student = DANN_GRL.DANN_GRL_Resnet(ResNet.resnet50, True, n_classes).to(device)

    if torch.cuda.device_count() > 1:
        teachers = [nn.DataParallel(t).to(device) for t in teachers]
        student = nn.DataParallel(student).to(device)

    # One distiller per teacher; all of them share the same student.
    distils = [distiller.Distiller(t, student).to(device) for t in teachers]

    criterion = nn.CrossEntropyLoss()
    if init_beta == 0.0:
        growth_rate = torch.zeros(1)
    else:
        growth_rate = torch.log(torch.FloatTensor([end_beta / init_beta])) / torch.FloatTensor([epochs])

    def make_sgd(params, base_lr):
        # Shared SGD settings for every optimizer in this experiment.
        return torch.optim.SGD(params, base_lr, momentum=momentum, weight_decay=weight_decay)

    optimizer_das = [
        make_sgd(list(d.s_net.parameters()) + list(d.t_net.parameters()), init_lr_da)
        for d in distils
    ]
    optimizer_kds = [
        make_sgd(list(d.s_net.parameters()) + list(d.Connectors.parameters()), init_lr_kd)
        for d in distils
    ]

    od_mmd_train(init_lr_da, init_lr_kd, epochs, growth_rate, alpha, gamma, init_beta,
                 distils, source_dataloader, targets_dataloader, targets_testloader,
                 optimizer_das, optimizer_kds, criterion, device, batch_norm,
                 is_scheduler_da=is_scheduler_da, is_scheduler_kd=is_scheduler_kd,
                 scheduler_da=None, scheduler_kd=None)
Example 10
0
def main():
    """Alternating DA / Hinton KD with DAN ResNet-50 -> ResNet-34 on Office-31.

    Amazon is the source domain and webcam is the target. Both optimizers are
    built over the teacher and student parameters together. Training is
    delegated to ``grl_multi_target_hinton_alt``.
    """
    train_bs = 32
    eval_bs = 32
    is_debug = False
    student_pretrained = True
    dev = torch.device("cuda")

    n_epochs = 400
    lr_da = 0.001
    lr_kd = 0.001
    momentum = 0.9
    weight_decay = 5e-4
    T = 20
    alpha = 0.3
    beta_start = 0.1
    beta_end = 0.9

    webcam = os.path.expanduser("~/datasets/webcam/images")
    amazon = os.path.expanduser("~/datasets/amazon/images")

    use_dp = torch.cuda.device_count() > 1

    def place(net):
        # Wrap for multi-GPU when available, then move to the CUDA device.
        return (nn.DataParallel(net) if use_dp else net).to(dev)

    teacher = place(DAN_model.DANNet_ResNet(ResNet.resnet50, True))
    student = place(DAN_model.DANNet_ResNet(ResNet.resnet34, student_pretrained))

    growth_rate = torch.log(torch.FloatTensor(
        [beta_end / beta_start])) / torch.FloatTensor([n_epochs])

    def make_sgd(base_lr):
        # NOTE(review): the KD optimizer also receives the teacher's parameters;
        # other experiments optimise only the student during KD -- confirm.
        params = list(teacher.parameters()) + list(student.parameters())
        return torch.optim.SGD(params, base_lr,
                               momentum=momentum, weight_decay=weight_decay)

    optimizer_da = make_sgd(lr_da)
    optimizer_kd = make_sgd(lr_kd)

    src_loader = DA_datasets.office_loader(amazon, train_bs, 1)
    tgt_loader = DA_datasets.office_loader(webcam, train_bs, 1)
    tgt_test = DA_datasets.office_test_loader(webcam, eval_bs, 1)

    logger = LoggerForSacred(None, None, True)

    grl_multi_target_hinton_alt(lr_da, dev, n_epochs, T, alpha, growth_rate,
                                beta_start, src_loader, tgt_loader, tgt_test,
                                optimizer_da, optimizer_kd, teacher, student,
                                logger=logger, scheduler=None, is_debug=is_debug)
Example 11
0
def main():
    """Single-target GRL distillation on Office-31: amazon -> webcam.

    Reuses the multi-target routine ``grl_multi_target_hinton_alt`` with one
    ResNet-101 teacher and a ResNet-34 student. The single target loaders are
    wrapped in one-element lists to match its interface.

    Returns:
        The best student accuracy returned by ``grl_multi_target_hinton_alt``.
    """
    a = os.path.expanduser('./datasets/amazon/images')
    w = os.path.expanduser('./datasets/webcam/images')

    dataset_name = 'Office31'

    source_dataset_path = a
    target_dataset_path_1 = w

    init_beta = 0.1
    end_beta = 0.6
    init_lr_da = 0.001
    init_lr_kd = 0.01
    momentum = 0.9
    T = 20
    batch_size = 32
    alpha = 0.5
    gamma = 0.5
    epochs = 400
    scheduler_kd_fn = None
    # Only read when scheduler_kd_fn is set (e.g. torch.optim.lr_scheduler.StepLR).
    # These names used to be undefined, so enabling the scheduler raised NameError.
    # TODO: placeholder values -- tune before enabling the KD scheduler.
    scheduler_kd_steps = 100
    scheduler_kd_gamma = 0.1
    batch_norm = True
    device = 'cuda'
    weight_decay = 5e-4
    is_scheduler_da = True
    is_scheduler_kd = True

    source_dataloader, targets_dataloader, targets_testloader = DA_datasets.get_source_target_loader(
        dataset_name,
        source_dataset_path,
        target_dataset_path_1,
        batch_size,
        0,
        drop_last=True)

    teacher_model_1 = DANN_GRL.DANN_GRL_Resnet(
        ResNet.resnet101, True,
        source_dataloader.dataset.num_classes).to(device)
    student_model = DANN_GRL.DANN_GRL_Resnet(
        ResNet.resnet34, True,
        source_dataloader.dataset.num_classes).to(device)
    # The multi-target trainer expects lists of target loaders.
    targets_dataloader = [targets_dataloader]
    targets_testloader = [targets_testloader]

    if torch.cuda.device_count() > 1:
        teacher_model_1 = nn.DataParallel(teacher_model_1).to(device)
        student_model = nn.DataParallel(student_model).to(device)

    teacher_models = [teacher_model_1]

    growth_rate = torch.zeros(1)
    if init_beta != 0.0:
        growth_rate = torch.log(torch.FloatTensor(
            [end_beta / init_beta])) / torch.FloatTensor([epochs])

    optimizer_da_1 = torch.optim.SGD(teacher_model_1.parameters(),
                                     init_lr_da,
                                     momentum=momentum,
                                     weight_decay=weight_decay)

    optimizer_das = [optimizer_da_1]

    optimizer_kd = torch.optim.SGD(student_model.parameters(),
                                   init_lr_kd,
                                   momentum=momentum,
                                   weight_decay=weight_decay)

    scheduler_kd = None
    if scheduler_kd_fn is not None:
        scheduler_kd = scheduler_kd_fn(optimizer_kd, scheduler_kd_steps,
                                       scheduler_kd_gamma)

    best_student_acc = grl_multi_target_hinton_alt(
        init_lr_da,
        init_lr_kd,
        device,
        epochs,
        T,
        alpha,
        gamma,
        growth_rate,
        init_beta,
        source_dataloader,
        targets_dataloader,
        targets_testloader,
        optimizer_das,
        optimizer_kd,
        teacher_models,
        student_model,
        logger=None,
        is_scheduler_da=is_scheduler_da,
        is_scheduler_kd=is_scheduler_kd,
        # Forward the scheduler built above (previously always None).
        scheduler_kd=scheduler_kd,
        scheduler_da=None,
        is_debug=False,
        batch_norm=batch_norm)

    return best_student_acc