Example #1
import os

import torch
import torch.nn as nn

# Project-local modules (module paths assumed from the repository layout).
import DA_datasets
import DANN_GRL
import ResNet
import distiller
# mmd_hinton_alt (the joint DA + KD training loop) must also be imported from
# wherever it is defined in this repository.


def main():

    batch_size = 16
    test_batch_size = 32

    p = os.path.expanduser("./image-clef/p")
    c = os.path.expanduser("./image-clef/c")
    i = os.path.expanduser("./image-clef/i")
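    # ImageCLEF-DA domain folders: p = Pascal VOC, c = Caltech-256, i = ImageNet.
    # Only p (source) and c (target) are used below; i is defined but left unused here.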
    is_debug = False

    epochs = 400
    init_lr_da = 0.0001
    init_lr_kd = 0.0005
    momentum = 0.9
    weight_decay = 5e-4
    device = torch.device("cuda")
    alpha = 0.5
    gamma = 0.5
    init_beta = 0.1
    end_beta = 0.9

    student_pretrained = True

    source_dataloader, target_dataloader, target_testloader = DA_datasets.get_source_target_loader(
        "ImageClef", p, c, batch_size, 0)
    if torch.cuda.device_count() > 1:
        teacher_model = nn.DataParallel(
            DANN_GRL.DANN_GRL_Resnet(ResNet.resnet101, True,
                                     source_dataloader.dataset.num_classes)).to(device)
        student_model = nn.DataParallel(
            DANN_GRL.DANN_GRL_Resnet(ResNet.resnet34, student_pretrained,
                                     source_dataloader.dataset.num_classes)).to(device)
    else:
        teacher_model = DANN_GRL.DANN_GRL_Resnet(
            ResNet.resnet101, True,
            source_dataloader.dataset.num_classes).to(device)
        student_model = DANN_GRL.DANN_GRL_Resnet(
            ResNet.resnet34, student_pretrained,
            source_dataloader.dataset.num_classes).to(device)

    distil = distiller.Distiller(teacher_model, student_model)
    distil = distil.to(device)
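    # distiller.Distiller wraps the two networks; its t_net (teacher), s_net (student)
    # and Connectors (presumably the feature-adaptation layers used by the distillation
    # loss) are optimized separately below.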
    criterion = nn.CrossEntropyLoss()

    growth_rate = torch.log(torch.FloatTensor([end_beta / init_beta])) / torch.FloatTensor([epochs])
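    # With this growth rate, beta can ramp exponentially from init_beta to end_beta
    # over `epochs` epochs (presumably as beta_t = init_beta * exp(growth_rate * t)
    # inside the training loop), since init_beta * exp(log(end_beta / init_beta)) = end_beta.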


    optimizer_da = torch.optim.SGD(distil.t_net.parameters(), init_lr_da,
                                   momentum=momentum, weight_decay=weight_decay)

    optimizer_kd = torch.optim.SGD(list(distil.s_net.parameters()) + list(distil.Connectors.parameters()),
                                   init_lr_kd, momentum=momentum, weight_decay=weight_decay)

    mmd_hinton_alt(init_lr_da, init_lr_kd, device, epochs, alpha, gamma, growth_rate,
                   init_beta, source_dataloader, target_dataloader, target_testloader,
                   optimizer_da, optimizer_kd, distil, criterion,
                   is_scheduler_kd=False, is_scheduler_da=True, is_debug=is_debug,
                   best_teacher_acc=0)
Example #2
import os

import torch
import torch.nn as nn

# Project-local modules (module paths assumed from the repository layout).
import DA_datasets
import DANN_GRL
import ResNet
# grl_multi_target_hinton_alt (the multi-target DA + KD training loop) must also
# be imported from wherever it is defined in this repository.


def main():

    a = os.path.expanduser('./datasets/amazon/images')
    w = os.path.expanduser('./datasets/webcam/images')
    d = os.path.expanduser('./datasets/dslr/images')

    dataset_name = 'Office31'

    source_dataset_path = a
    target_dataset_path_1 = w

    init_beta = 0.1
    end_beta = 0.6
    init_lr_da = 0.001
    init_lr_kd = 0.01
    momentum = 0.9
    T = 20
    batch_size = 32
    alpha = 0.5
    gamma = 0.5
    epochs = 400
    scheduler_kd_fn = None
    batch_norm = True
    device = 'cuda'
    weight_decay = 5e-4
    is_scheduler_da = True
    is_scheduler_kd = True

    source_dataloader, targets_dataloader, targets_testloader = DA_datasets.get_source_target_loader(
        dataset_name,
        source_dataset_path,
        target_dataset_path_1,
        batch_size,
        0,
        drop_last=True)

    teacher_model_1 = DANN_GRL.DANN_GRL_Resnet(
        ResNet.resnet101, True,
        source_dataloader.dataset.num_classes).to(device)
    student_model = DANN_GRL.DANN_GRL_Resnet(
        ResNet.resnet34, True,
        source_dataloader.dataset.num_classes).to(device)
    targets_dataloader = [targets_dataloader]
    targets_testloader = [targets_testloader]
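    # The multi-target training loop takes lists (one entry per target domain),
    # so the single target is wrapped in one-element lists here.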

    if torch.cuda.device_count() > 1:
        teacher_model_1 = nn.DataParallel(teacher_model_1).to(device)
        student_model = nn.DataParallel(student_model).to(device)

    teacher_models = [teacher_model_1]

    growth_rate = torch.zeros(1)
    if init_beta != 0.0:
        growth_rate = torch.log(torch.FloatTensor(
            [end_beta / init_beta])) / torch.FloatTensor([epochs])
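    # The init_beta != 0 guard avoids a division by zero; otherwise growth_rate gives
    # an exponential ramp of beta from init_beta to end_beta over `epochs` epochs.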

    optimizer_da_1 = torch.optim.SGD(teacher_model_1.parameters(),
                                     init_lr_da,
                                     momentum=momentum,
                                     weight_decay=weight_decay)

    optimizer_das = [optimizer_da_1]
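    # One domain-adaptation optimizer per teacher, collected into a list for the
    # multi-target training loop.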

    optimizer_kd = torch.optim.SGD(student_model.parameters(),
                                   init_lr_kd,
                                   momentum=momentum,
                                   weight_decay=weight_decay)

    scheduler_kd = None
    if scheduler_kd_fn is not None:
        # scheduler_kd_steps and scheduler_kd_gamma are not defined in this example;
        # they must be supplied whenever a scheduler_kd_fn is actually passed in.
        scheduler_kd = scheduler_kd_fn(optimizer_kd, scheduler_kd_steps,
                                       scheduler_kd_gamma)

    best_student_acc = grl_multi_target_hinton_alt(
        init_lr_da,
        init_lr_kd,
        device,
        epochs,
        T,
        alpha,
        gamma,
        growth_rate,
        init_beta,
        source_dataloader,
        targets_dataloader,
        targets_testloader,
        optimizer_das,
        optimizer_kd,
        teacher_models,
        student_model,
        logger=None,
        is_scheduler_da=is_scheduler_da,
        is_scheduler_kd=is_scheduler_kd,
        scheduler_kd=scheduler_kd,
        scheduler_da=None,
        is_debug=False,
        batch_norm=batch_norm)
Example #3
import os

import torch
import torch.nn as nn

# Project-local modules (module paths assumed from the repository layout).
import DA_datasets
import DAN_model
import ResNet
# mmd_hinton_alt (the DA + KD training loop) and the VisdomLogger / LoggerForSacred
# utilities must also be imported from wherever they are defined in this repository.


def main():

    useVisdomLogger = True

    batch_size = 32
    test_batch_size = 32

    webcam = os.path.expanduser("~/datasets/webcam/images")
    amazon = os.path.expanduser("~/datasets/amazon/images")
    dslr = os.path.expanduser("~/datasets/dslr/images")
    is_debug = False

    epochs = 400
    init_lr_da = 0.001
    init_lr_kd = 0.001
    momentum = 0.9
    weight_decay = 5e-4
    device = torch.device("cuda")
    T = 20
    alpha = 0.3
    init_beta = 0.1
    end_beta = 0.9

    student_pretrained = True

    if torch.cuda.device_count() > 1:
        teacher_model = nn.DataParallel(
            DAN_model.DANNet_ResNet(ResNet.resnet50, True)).to(device)
        student_model = nn.DataParallel(
            DAN_model.DANNet_ResNet(ResNet.resnet34,
                                    student_pretrained)).to(device)
    else:
        teacher_model = DAN_model.DANNet_ResNet(ResNet.resnet50,
                                                True).to(device)
        student_model = DAN_model.DANNet_ResNet(ResNet.resnet34,
                                                student_pretrained).to(device)

    growth_rate = torch.log(torch.FloatTensor(
        [end_beta / init_beta])) / torch.FloatTensor([epochs])

    optimizer_da = torch.optim.SGD(list(teacher_model.parameters()) +
                                   list(student_model.parameters()),
                                   init_lr_da,
                                   momentum=momentum,
                                   weight_decay=weight_decay)

    optimizer_kd = torch.optim.SGD(list(teacher_model.parameters()) +
                                   list(student_model.parameters()),
                                   init_lr_kd,
                                   momentum=momentum,
                                   weight_decay=weight_decay)
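    # Both optimizers cover the same teacher + student parameters; they are kept
    # separate presumably so the DA and KD phases can be scheduled independently
    # (only the DA optimizer is scheduled in the call below).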

    source_dataloader, target_dataloader, target_testloader = DA_datasets.get_source_target_loader(
        "Office31", amazon, webcam, batch_size, 0)
    logger = None
    if useVisdomLogger:
        logger = VisdomLogger(port=9000)
    logger = LoggerForSacred(logger, always_print=True)
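    # LoggerForSacred wraps the (optional) Visdom logger; always_print=True presumably
    # also echoes the logged metrics to stdout.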

    mmd_hinton_alt(init_lr_da,
                   init_lr_kd,
                   device,
                   epochs,
                   T,
                   alpha,
                   growth_rate,
                   init_beta,
                   source_dataloader,
                   target_dataloader,
                   target_testloader,
                   optimizer_da,
                   optimizer_kd,
                   teacher_model,
                   student_model,
                   logger=logger,
                   is_scheduler_kd=False,
                   is_scheduler_da=True,
                   is_debug=False)