Example #1
0
def main(args):
    """Adapt a source-trained GTA classifier to the target domain (WDGRL).

    Loads the source checkpoint selected by ``args.model`` and then, for
    every iteration, alternates between two phases:

    1. train a critic on frozen features so that it estimates the
       Wasserstein distance between source and target features, and
    2. train the classifier on labelled source batches while also
       minimising the critic's distance estimate.

    The classifier weights are saved to the branch-specific ``out_file``
    after every epoch.

    Args:
        args: Namespace providing ``model`` ('gta', 'gta-res' or
            'gta-vgg'), ``batch_size``, ``epochs``, ``iterations``,
            ``k_critic``, ``k_clf``, ``gamma`` (weight of the gradient
            penalty) and ``wd_clf`` (weight of the Wasserstein term in the
            classifier loss).

    Raises:
        ValueError: If ``args.model`` is not one of the supported names.
    """
    # NOTE: in every branch ``discriminator`` is the label-classification
    # head of the network (it produces the logits for the cross-entropy
    # loss); the domain critic is built separately below.
    if args.model == 'gta':
        model_file = './trained_models/gta_source.pt'
        out_file = './trained_models/gta_wdgrl.pt'
        out_ftrs = 4375  # input width of the critic for GTANet features

        clf_model = GTANet().to(device)
        clf_model.load_state_dict(torch.load(model_file))
        feature_extractor = clf_model.feature_extractor
        discriminator = clf_model.classifier

    elif args.model == 'gta-res':
        model_file = './trained_models/gta_res_source.pt'
        out_file = './trained_models/gta_res_wdgrl.pt'

        clf_model = GTARes18Net(9, pretrained=False).to(device)
        out_ftrs = clf_model.fc.in_features
        clf_model.load_state_dict(torch.load(model_file))

        feature_extractor = clf_model.feature_extractor
        discriminator = clf_model.fc

    elif args.model == 'gta-vgg':
        model_file = './trained_models/gta_vgg_source.pt'
        out_file = './trained_models/gta_vgg_wdgrl.pt'

        clf_model = GTAVGG11Net(9, pretrained=False).to(device)
        out_ftrs = clf_model.classifier[0].in_features  # should be 512 * 7 * 7
        clf_model.load_state_dict(torch.load(model_file))
        # NOTE(review): this freezes the whole network, classifier head
        # included. The training loop only re-enables gradients on the
        # feature extractor, so the head stays frozen for this branch (and
        # only this branch) -- confirm that this is intended.
        set_requires_grad(clf_model, False)

        feature_extractor = clf_model.feature_extractor
        discriminator = clf_model.classifier

    else:
        raise ValueError(f'Unknown model type {args.model}')

    # Domain critic: maps a flattened feature vector to one unbounded score.
    critic = nn.Sequential(
        nn.Linear(out_ftrs, 64),
        nn.ReLU(),
        nn.Linear(64, 16),
        nn.ReLU(),
        nn.Linear(16, 1),
    ).to(device)

    # Each step consumes half a batch from each domain.
    half_batch = args.batch_size // 2
    # NOTE(review): the target set is read from './data' and the source set
    # from './t_data' -- confirm the two directories are not swapped.
    target_dataset = ImageFolder('./data',
                                 transform=Compose([
                                     Resize((398, 224)),
                                     RandomCrop(224),
                                     RandomHorizontalFlip(),
                                     ToTensor(),
                                     Normalize([0.485, 0.456, 0.406],
                                               [0.229, 0.224, 0.225]),
                                 ]))
    # drop_last=True keeps the source and target batches the same size.
    # gradient_penalty is defined elsewhere; it presumably combines h_s and
    # h_t element-wise, which fails when the final, shorter batches of the
    # two loaders differ in length.
    target_loader = DataLoader(target_dataset,
                               batch_size=half_batch,
                               drop_last=True,
                               shuffle=True,
                               num_workers=1,
                               pin_memory=True)

    source_dataset = ImageFolder('./t_data',
                                 transform=Compose([
                                     RandomCrop(224,
                                                pad_if_needed=True,
                                                padding_mode='reflect'),
                                     RandomHorizontalFlip(),
                                     ToTensor(),
                                     Normalize([0.485, 0.456, 0.406],
                                               [0.229, 0.224, 0.225]),
                                 ]))
    source_loader = DataLoader(source_dataset,
                               batch_size=half_batch,
                               drop_last=True,
                               shuffle=True,
                               num_workers=1,
                               pin_memory=True)

    critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-4)
    clf_optim = torch.optim.Adam(clf_model.parameters(), lr=1e-4)
    clf_criterion = nn.CrossEntropyLoss()

    for epoch in range(1, args.epochs + 1):
        # loop_iterable presumably restarts a loader once it is exhausted,
        # so an "epoch" here is args.iterations steps, not a dataset pass.
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))

        total_loss = 0
        for _ in trange(args.iterations, leave=False):
            (source_x, source_y), (target_x, _) = next(batch_iterator)
            # Train critic
            set_requires_grad(feature_extractor, requires_grad=False)
            set_requires_grad(critic, requires_grad=True)

            source_x, target_x = source_x.to(device), target_x.to(device)
            source_y = source_y.to(device)

            # The features are constants for the critic updates, so they
            # are computed once, without building an autograd graph.
            with torch.no_grad():
                h_s = feature_extractor(source_x).view(source_x.shape[0], -1)
                h_t = feature_extractor(target_x).view(target_x.shape[0], -1)
            for _ in range(args.k_critic):
                gp = gradient_penalty(critic, h_s, h_t)

                critic_s = critic(h_s)
                critic_t = critic(h_t)
                wasserstein_distance = critic_s.mean() - critic_t.mean()

                # The critic maximises the distance estimate, hence the
                # negation; the penalty is weighted by gamma.
                critic_cost = -wasserstein_distance + args.gamma * gp

                critic_optim.zero_grad()
                critic_cost.backward()
                critic_optim.step()

                total_loss += critic_cost.item()

            # Train classifier
            set_requires_grad(feature_extractor, requires_grad=True)
            set_requires_grad(critic, requires_grad=False)
            for _ in range(args.k_clf):
                source_features = feature_extractor(source_x).view(
                    source_x.shape[0], -1)
                target_features = feature_extractor(target_x).view(
                    target_x.shape[0], -1)

                source_preds = discriminator(source_features)
                clf_loss = clf_criterion(source_preds, source_y)
                wasserstein_distance = critic(source_features).mean() - critic(
                    target_features).mean()

                # Supervised loss plus the weighted domain-distance term.
                loss = clf_loss + args.wd_clf * wasserstein_distance
                clf_optim.zero_grad()
                loss.backward()
                clf_optim.step()

        mean_loss = total_loss / (args.iterations * args.k_critic)
        tqdm.write(f'EPOCH {epoch:03d}: critic_loss={mean_loss:.4f}')
        torch.save(clf_model.state_dict(), out_file)
Example #2
0
def main(args):
    """Adapt a source-trained ``Net`` from MNIST to MNISTM using WDGRL.

    Loads the checkpoint at ``args.MODEL_FILE`` and then, per iteration,
    first trains a critic on frozen features to estimate the Wasserstein
    distance between the two domains, and afterwards trains the classifier
    on labelled source batches while minimising that estimate. The weights
    are written to ``trained_models/wdgrl.pt`` after every epoch.

    Args:
        args: Namespace providing ``MODEL_FILE``, ``batch_size``,
            ``epochs``, ``iterations``, ``k_critic``, ``k_clf``, ``gamma``
            (gradient-penalty weight) and ``wd_clf`` (weight of the
            Wasserstein term in the classifier loss).
    """
    clf_model = Net().to(device)
    clf_model.load_state_dict(torch.load(args.MODEL_FILE))

    feature_extractor = clf_model.feature_extractor
    # NOTE: despite its name this is the label-classification head; the
    # domain critic is the separate network below.
    discriminator = clf_model.classifier

    # The critic takes the flattened features (320 values per sample) and
    # returns a single unbounded score.
    critic = nn.Sequential(nn.Linear(320, 50), nn.ReLU(), nn.Linear(50, 20),
                           nn.ReLU(), nn.Linear(20, 1)).to(device)

    # Each step consumes half a batch from each domain; drop_last keeps the
    # source and target batches the same size.
    half_batch = args.batch_size // 2
    source_dataset = MNIST(config.DATA_DIR / 'mnist',
                           train=True,
                           download=True,
                           transform=Compose([GrayscaleToRgb(),
                                              ToTensor()]))
    source_loader = DataLoader(source_dataset,
                               batch_size=half_batch,
                               drop_last=True,
                               shuffle=True,
                               num_workers=0,
                               pin_memory=True)

    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset,
                               batch_size=half_batch,
                               drop_last=True,
                               shuffle=True,
                               num_workers=0,
                               pin_memory=True)

    critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-4)
    clf_optim = torch.optim.Adam(clf_model.parameters(), lr=1e-4)
    clf_criterion = nn.CrossEntropyLoss()

    for epoch in range(1, args.epochs + 1):
        # loop_iterable presumably restarts a loader once it is exhausted,
        # so an "epoch" here is args.iterations steps, not a dataset pass.
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))

        total_loss = 0
        for _ in trange(args.iterations, leave=False):
            (source_x, source_y), (target_x, _) = next(batch_iterator)
            # Train critic
            set_requires_grad(feature_extractor, requires_grad=False)
            set_requires_grad(critic, requires_grad=True)

            source_x, target_x = source_x.to(device), target_x.to(device)
            source_y = source_y.to(device)

            # The features are constants for the critic updates, so they
            # are computed once, without building an autograd graph.
            with torch.no_grad():
                h_s = feature_extractor(source_x).view(source_x.shape[0], -1)
                h_t = feature_extractor(target_x).view(target_x.shape[0], -1)
            for _ in range(args.k_critic):
                gp = gradient_penalty(critic, h_s, h_t)

                critic_s = critic(h_s)
                critic_t = critic(h_t)
                wasserstein_distance = critic_s.mean() - critic_t.mean()

                # The critic maximises the distance estimate, hence the
                # negation; the penalty is weighted by gamma.
                critic_cost = -wasserstein_distance + args.gamma * gp

                critic_optim.zero_grad()
                critic_cost.backward()
                critic_optim.step()

                total_loss += critic_cost.item()

            # Train classifier
            set_requires_grad(feature_extractor, requires_grad=True)
            set_requires_grad(critic, requires_grad=False)
            for _ in range(args.k_clf):
                source_features = feature_extractor(source_x).view(
                    source_x.shape[0], -1)
                target_features = feature_extractor(target_x).view(
                    target_x.shape[0], -1)

                source_preds = discriminator(source_features)
                clf_loss = clf_criterion(source_preds, source_y)
                wasserstein_distance = critic(source_features).mean() - critic(
                    target_features).mean()

                # Supervised loss plus the weighted domain-distance term.
                loss = clf_loss + args.wd_clf * wasserstein_distance
                clf_optim.zero_grad()
                loss.backward()
                clf_optim.step()

        mean_loss = total_loss / (args.iterations * args.k_critic)
        tqdm.write(f'EPOCH {epoch:03d}: critic_loss={mean_loss:.4f}')
        torch.save(clf_model.state_dict(), 'trained_models/wdgrl.pt')
Example #3
0
def main(args):
    """Run ADDA independently for ten source/target model pairs.

    Every pair starts from the checkpoint at ``args.MODEL_FILE``. For pair
    ``idx`` the source data comes from ``preprocess_train_single(idx)`` and
    the target data from ``preprocess_test(args.person)``. In each epoch and
    for each pair, a discriminator is trained to tell frozen source features
    from target features, and the target feature extractor is then trained
    to fool it. After each pair's update the classifier is evaluated with
    ``test`` and saved to ``trained_models/adda<idx>.pt`` whenever it
    improves; after each epoch all classifiers are evaluated together with
    ``test_all``. The per-epoch ``test_all`` results are finally dumped to
    ``<args.seed>/acc<args.person>.json``.

    Args:
        args: Namespace providing ``MODEL_FILE``, ``batch_size``,
            ``person``, ``epochs``, ``iterations``, ``k_disc``, ``k_clf``
            and ``seed``.
    """
    num_models = 10

    final_accs = []

    # Frozen source networks. The full networks double as the classifiers
    # that are evaluated and saved; their feature extractors provide the
    # fixed source features.
    clfs = [Net().to(device) for _ in range(num_models)]
    for source_net in clfs:
        source_net.load_state_dict(torch.load(args.MODEL_FILE))
        source_net.eval()
        set_requires_grad(source_net, requires_grad=False)
    source_models = [source_net.feature_extractor for source_net in clfs]

    # Trainable target feature extractors, initialised from the same
    # checkpoint.
    target_nets = [Net().to(device) for _ in range(num_models)]
    target_models = []
    for target_net in target_nets:
        target_net.load_state_dict(torch.load(args.MODEL_FILE))
        target_models.append(target_net.feature_extractor)

    # The discriminators emit raw logits. BCEWithLogitsLoss applies the
    # sigmoid itself, so the network must not end in nn.Sigmoid (that would
    # apply the sigmoid twice and confine the effective probability to a
    # narrow range above 0.5).
    discriminators = [
        nn.Sequential(nn.Linear(EXTRACTED_FEATURE_DIM, 64), nn.ReLU(),
                      nn.BatchNorm1d(64), nn.Linear(64, 1)).to(device)
        for _ in range(num_models)
    ]

    batch_size = args.batch_size
    discriminator_optims = [
        torch.optim.Adam(discriminators[idx].parameters(), lr=1e-5)
        for idx in range(num_models)
    ]
    target_optims = [
        torch.optim.Adam(target_models[idx].parameters(), lr=1e-5)
        for idx in range(num_models)
    ]
    criterion = nn.BCEWithLogitsLoss()

    # Only the datasets are kept; fresh shuffling loaders are built from
    # them in every epoch.
    source_datasets = []
    target_datasets = []
    for idx in range(num_models):
        X_source, y_source = preprocess_train_single(idx)
        source_datasets.append(
            torch.utils.data.TensorDataset(X_source, y_source))

        X_target, y_target = preprocess_test(args.person)
        target_datasets.append(
            torch.utils.data.TensorDataset(X_target, y_target))

    # Baseline before any adaptation (presumably the voting accuracy of the
    # ensemble -- test_all is defined elsewhere).
    best_voting_acc = test_all(clfs)
    best_tar_accs = [0.0] * num_models

    for epoch in range(1, args.epochs + 1):
        for idx in range(num_models):
            source_loader = DataLoader(source_datasets[idx],
                                       batch_size=batch_size,
                                       shuffle=True)
            target_loader = DataLoader(target_datasets[idx],
                                       batch_size=batch_size,
                                       shuffle=True)
            # loop_iterable presumably restarts a loader once exhausted.
            batch_iterator = zip(loop_iterable(source_loader),
                                 loop_iterable(target_loader))
            target_model = target_models[idx]
            discriminator = discriminators[idx]
            source_model = source_models[idx]
            clf = clfs[idx]
            total_loss = 0
            adv_loss = 0
            total_accuracy = 0
            second_acc = 0
            for _ in trange(args.iterations, leave=False):
                # Train discriminator
                set_requires_grad(target_model, requires_grad=False)
                set_requires_grad(discriminator, requires_grad=True)
                discriminator.train()
                for _ in range(args.k_disc):
                    (source_x, _), (target_x, _) = next(batch_iterator)
                    source_x, target_x = source_x.to(device), target_x.to(
                        device)

                    source_features = source_model(source_x).view(
                        source_x.shape[0], -1)
                    target_features = target_model(target_x).view(
                        target_x.shape[0], -1)

                    # Domain labels: 1 = source, 0 = target.
                    discriminator_x = torch.cat(
                        [source_features, target_features])
                    discriminator_y = torch.cat([
                        torch.ones(source_x.shape[0], device=device),
                        torch.zeros(target_x.shape[0], device=device)
                    ])

                    # squeeze(1) removes only the output dimension, so a
                    # batch of one keeps its batch dimension.
                    preds = discriminator(discriminator_x).squeeze(1)
                    loss = criterion(preds, discriminator_y)

                    discriminator_optims[idx].zero_grad()
                    loss.backward()
                    discriminator_optims[idx].step()

                    total_loss += loss.item()
                    # logit >= 0 is the same as probability >= 0.5
                    total_accuracy += ((preds >= 0).long(
                    ) == discriminator_y.long()).float().mean().item()

                # Train classifier
                set_requires_grad(target_model, requires_grad=True)
                set_requires_grad(discriminator, requires_grad=False)
                target_model.train()
                for _ in range(args.k_clf):
                    _, (target_x, _) = next(batch_iterator)
                    target_x = target_x.to(device)
                    target_features = target_model(target_x).view(
                        target_x.shape[0], -1)

                    # flipped labels
                    discriminator_y = torch.ones(target_x.shape[0],
                                                 device=device)

                    preds = discriminator(target_features).squeeze(1)
                    # Share of target samples the discriminator already
                    # mistakes for source samples.
                    second_acc += ((preds >= 0).long() == discriminator_y.
                                   long()).float().mean().item()

                    loss = criterion(preds, discriminator_y)
                    adv_loss += loss.item()

                    target_optims[idx].zero_grad()
                    loss.backward()
                    target_optims[idx].step()

            mean_loss = total_loss / (args.iterations * args.k_disc)
            mean_adv_loss = adv_loss / (args.iterations * args.k_clf)
            dis_accuracy = total_accuracy / (args.iterations * args.k_disc)
            sec_acc = second_acc / (args.iterations * args.k_clf)

            # Create the full target model (adapted extractor + source
            # classifier head), evaluate it and keep the best checkpoint.
            clf.feature_extractor = target_model
            tar_accuracy = test(args, clf)
            if tar_accuracy > best_tar_accs[idx]:
                best_tar_accs[idx] = tar_accuracy
                torch.save(clf.state_dict(),
                           'trained_models/adda' + str(idx) + '.pt')

            tqdm.write(
                f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, adv_loss = {mean_adv_loss:.4f}, '
                f'discriminator_accuracy={dis_accuracy:.4f}, tar_accuary = {tar_accuracy:.4f}, best_accuracy = {best_tar_accs[idx]:.4f}, sec_acc = {sec_acc:.4f}'
            )

        acc = test_all(clfs)
        final_accs.append(acc)
        if acc > best_voting_acc:
            best_voting_acc = acc
        print("In epoch %d, voting_acc: %.4f, best_voting_acc: %.4f" %
              (epoch, acc, best_voting_acc))
    jd = {"test_acc": final_accs}
    with open(str(args.seed) + '/acc' + str(args.person) + '.json', 'w') as f:
        json.dump(jd, f)
Example #4
0
def main(args):
    """Adapt a source-trained ``Net`` from MNIST to MNISTM using ADDA.

    Two copies of the checkpoint at ``args.MODEL_FILE`` are loaded: a frozen
    source network and a trainable target feature extractor. Per iteration
    a discriminator is trained to tell source features from target features,
    and the target extractor is then trained to fool it. After every epoch
    the adapted extractor is combined with the source classifier head and
    saved to ``trained_models/adda.pt``.

    Args:
        args: Namespace providing ``MODEL_FILE``, ``batch_size``,
            ``epochs``, ``iterations``, ``k_disc`` and ``k_clf``.
    """
    # Frozen source network; it is also the model that gets saved.
    clf = Net().to(device)
    clf.load_state_dict(torch.load(args.MODEL_FILE))
    clf.eval()
    set_requires_grad(clf, requires_grad=False)
    source_model = clf.feature_extractor

    # Trainable target feature extractor, same initialisation.
    target_net = Net().to(device)
    target_net.load_state_dict(torch.load(args.MODEL_FILE))
    target_model = target_net.feature_extractor
    target_clf = clf.classifier

    # The domain discriminator takes flattened features (320 values per
    # sample) and returns one logit.
    discriminator = nn.Sequential(nn.Linear(320, 50), nn.ReLU(),
                                  nn.Linear(50, 20), nn.ReLU(),
                                  nn.Linear(20, 1)).to(device)

    half_batch = args.batch_size // 2
    source_dataset = MNIST(config.DATA_DIR / 'mnist',
                           train=True,
                           download=True,
                           transform=Compose([GrayscaleToRgb(),
                                              ToTensor()]))
    source_loader = DataLoader(source_dataset,
                               batch_size=half_batch,
                               shuffle=True,
                               num_workers=1,
                               pin_memory=True)

    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset,
                               batch_size=half_batch,
                               shuffle=True,
                               num_workers=1,
                               pin_memory=True)

    discriminator_optim = torch.optim.Adam(discriminator.parameters())
    target_optim = torch.optim.Adam(target_model.parameters())
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(1, args.epochs + 1):
        # loop_iterable presumably restarts a loader once it is exhausted.
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))

        total_loss = 0
        total_accuracy = 0
        target_label_accuracy = 0
        for _ in trange(args.iterations, leave=False):
            # Train discriminator
            set_requires_grad(target_model, requires_grad=False)
            set_requires_grad(discriminator, requires_grad=True)
            for _ in range(args.k_disc):
                (source_x, _), (target_x, _) = next(batch_iterator)
                source_x, target_x = source_x.to(device), target_x.to(device)

                source_features = source_model(source_x).view(
                    source_x.shape[0], -1)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)

                # Domain labels: 1 = source, 0 = target.
                discriminator_x = torch.cat([source_features, target_features])
                discriminator_y = torch.cat([
                    torch.ones(source_x.shape[0], device=device),
                    torch.zeros(target_x.shape[0], device=device)
                ])

                # squeeze(1) removes only the output dimension, so a batch
                # of one keeps its batch dimension and still matches the
                # shape of the labels.
                preds = discriminator(discriminator_x).squeeze(1)
                loss = criterion(preds, discriminator_y)

                discriminator_optim.zero_grad()
                loss.backward()
                discriminator_optim.step()

                total_loss += loss.item()
                # logit > 0 is the same as probability > 0.5
                total_accuracy += ((
                    preds >
                    0).long() == discriminator_y.long()).float().mean().item()

            # Train classifier
            set_requires_grad(target_model, requires_grad=True)
            set_requires_grad(discriminator, requires_grad=False)
            for _ in range(args.k_clf):
                _, (target_x, target_labels) = next(batch_iterator)
                target_x = target_x.to(device)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)

                # flipped labels
                discriminator_y = torch.ones(target_x.shape[0], device=device)

                preds = discriminator(target_features).squeeze(1)
                loss = criterion(preds, discriminator_y)

                target_optim.zero_grad()
                loss.backward()
                target_optim.step()

                # Monitoring only: label accuracy on the target batch, using
                # the features computed before this optimiser step. No
                # autograd graph is needed for it.
                with torch.no_grad():
                    target_label_preds = target_clf(target_features)
                target_label_accuracy += (target_label_preds.cpu().max(1)[1] ==
                                          target_labels).float().mean().item()

        mean_loss = total_loss / (args.iterations * args.k_disc)
        mean_accuracy = total_accuracy / (args.iterations * args.k_disc)
        target_mean_accuracy = target_label_accuracy / (args.iterations *
                                                        args.k_clf)
        tqdm.write(
            f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, '
            f'discriminator_accuracy={mean_accuracy:.4f}, target_accuracy={target_mean_accuracy:.4f}'
        )

        # Create the full target model and save it
        clf.feature_extractor = target_model
        torch.save(clf.state_dict(), 'trained_models/adda.pt')
def main(args):
    """Adapt a source-trained GTA classifier to the target domain with ADDA.

    For the architecture selected by ``args.model`` two copies of the source
    checkpoint are loaded: ``model`` stays frozen and provides the source
    features, ``model_2`` provides the trainable target features. Per
    iteration a discriminator is trained to tell the two feature sets apart,
    and ``model_2`` is then trained to fool it. After every epoch the
    adapted feature layers are transplanted into the frozen classifier,
    which is saved to the branch-specific ``out_file``.

    Args:
        args: Namespace providing ``model`` ('gta', 'gta-res' or
            'gta-vgg'), ``batch_size``, ``epochs``, ``iterations``,
            ``k_disc`` and ``k_clf``.

    Raises:
        ValueError: If ``args.model`` is not one of the supported names.
    """
    if args.model == 'gta':
        model_file = './trained_models/gta_source.pt'
        out_file = './trained_models/gta_adda.pt'
        out_ftrs = 4375  # input width of the discriminator for GTANet

        model = GTANet().to(device)
        model.load_state_dict(torch.load(model_file))
        model.eval()
        set_requires_grad(model, False)
        source_model = model.feature_extractor
        clf = model

        model_2 = GTANet().to(device)
        model_2.load_state_dict(torch.load(model_file))
        target_model = model_2.feature_extractor

    elif args.model == 'gta-res':
        model_file = './trained_models/gta_res_source.pt'
        out_file = './trained_models/gta_res_adda.pt'

        model = GTARes18Net(9, pretrained=False).to(device)
        out_ftrs = model.fc.in_features
        model.load_state_dict(torch.load(model_file))
        model.eval()
        set_requires_grad(model, False)

        source_model = model.feature_extractor
        clf = model

        model_2 = GTARes18Net(9, pretrained=False).to(device)
        model_2.load_state_dict(torch.load(model_file))
        target_model = model_2.feature_extractor

    elif args.model == 'gta-vgg':
        model_file = './trained_models/gta_vgg_source.pt'
        out_file = './trained_models/gta_vgg_adda.pt'

        model = GTAVGG11Net(9, pretrained=False).to(device)
        out_ftrs = model.classifier[0].in_features  # should be 512 * 7 * 7
        model.load_state_dict(torch.load(model_file))
        model.eval()
        set_requires_grad(model, False)

        # For VGG the feature extractor is spelled out as a function:
        # convolutional features -> average pool -> flatten.
        def source_model(x):
            x = model.features(x)
            x = model.avgpool(x)
            x = torch.flatten(x, 1)
            return x

        clf = model

        model_2 = GTAVGG11Net(9, pretrained=False).to(device)
        model_2.load_state_dict(torch.load(model_file))

        def target_model(x):
            x = model_2.features(x)
            x = model_2.avgpool(x)
            x = torch.flatten(x, 1)
            return x

    else:
        raise ValueError(f'Unknown model type {args.model}')

    # The domain discriminator takes flattened features and returns one
    # logit.
    discriminator = nn.Sequential(
        nn.Linear(out_ftrs, 64),
        nn.ReLU(),
        nn.Linear(64, 16),
        nn.ReLU(),
        nn.Linear(16, 1),
    ).to(device)

    half_batch = args.batch_size // 2
    # NOTE(review): the target set is read from './data' and the source set
    # from './t_data' -- confirm the two directories are not swapped.
    target_dataset = ImageFolder('./data',
                                 transform=Compose([
                                     Resize((398, 224)),
                                     RandomCrop(224),
                                     RandomHorizontalFlip(),
                                     ToTensor(),
                                     Normalize([0.485, 0.456, 0.406],
                                               [0.229, 0.224, 0.225]),
                                 ]))
    target_loader = DataLoader(target_dataset,
                               batch_size=half_batch,
                               shuffle=True,
                               num_workers=1,
                               pin_memory=True)

    source_dataset = ImageFolder('./t_data',
                                 transform=Compose([
                                     RandomCrop(224,
                                                pad_if_needed=True,
                                                padding_mode='reflect'),
                                     RandomHorizontalFlip(),
                                     ToTensor(),
                                     Normalize([0.485, 0.456, 0.406],
                                               [0.229, 0.224, 0.225]),
                                 ]))
    source_loader = DataLoader(source_dataset,
                               batch_size=half_batch,
                               shuffle=True,
                               num_workers=1,
                               pin_memory=True)

    discriminator_optim = torch.optim.Adam(discriminator.parameters())
    # The optimiser covers all of model_2; only the layers that take part
    # in target_model receive gradients from the adversarial loss.
    target_optim = torch.optim.Adam(model_2.parameters())
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(1, args.epochs + 1):
        # loop_iterable presumably restarts a loader once it is exhausted.
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))

        total_loss = 0
        total_accuracy = 0
        for _ in trange(args.iterations, leave=False):
            # Train discriminator
            set_requires_grad(model_2, requires_grad=False)
            set_requires_grad(discriminator, requires_grad=True)
            for _ in range(args.k_disc):
                (source_x, _), (target_x, _) = next(batch_iterator)
                source_x, target_x = source_x.to(device), target_x.to(device)

                source_features = source_model(source_x).view(
                    source_x.shape[0], -1)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)

                # Domain labels: 1 = source, 0 = target.
                discriminator_x = torch.cat([source_features, target_features])
                discriminator_y = torch.cat([
                    torch.ones(source_x.shape[0], device=device),
                    torch.zeros(target_x.shape[0], device=device)
                ])

                # squeeze(1) removes only the output dimension, so a batch
                # of one keeps its batch dimension and still matches the
                # shape of the labels.
                preds = discriminator(discriminator_x).squeeze(1)
                loss = criterion(preds, discriminator_y)

                discriminator_optim.zero_grad()
                loss.backward()
                discriminator_optim.step()

                total_loss += loss.item()
                # logit > 0 is the same as probability > 0.5
                total_accuracy += ((preds > 0).long() == discriminator_y.long()
                                  ).float().mean().item()

            # Train classifier
            set_requires_grad(model_2, requires_grad=True)
            set_requires_grad(discriminator, requires_grad=False)
            for _ in range(args.k_clf):
                _, (target_x, _) = next(batch_iterator)
                target_x = target_x.to(device)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)

                # flipped labels
                discriminator_y = torch.ones(target_x.shape[0], device=device)

                preds = discriminator(target_features).squeeze(1)
                loss = criterion(preds, discriminator_y)

                target_optim.zero_grad()
                loss.backward()
                target_optim.step()

        mean_loss = total_loss / (args.iterations * args.k_disc)
        mean_accuracy = total_accuracy / (args.iterations * args.k_disc)
        tqdm.write(f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, '
                   f'discriminator_accuracy={mean_accuracy:.4f}')

        # Create the full target model and save it
        if args.model == 'gta':
            clf.feature_extractor = target_model
        elif args.model == 'gta-res':
            clf.conv1 = model_2.conv1
            clf.bn1 = model_2.bn1
            clf.relu = model_2.relu
            clf.maxpool = model_2.maxpool

            clf.layer1 = model_2.layer1
            clf.layer2 = model_2.layer2
            clf.layer3 = model_2.layer3
            clf.layer4 = model_2.layer4

            clf.avgpool = model_2.avgpool
        elif args.model == 'gta-vgg':
            # Without this branch the checkpoint would contain the
            # untouched source weights. These are exactly the sub-modules
            # that target_model runs for VGG.
            # NOTE(review): if GTAVGG11Net also exposes these layers under
            # another attribute, check that the saved state_dict picks up
            # the adapted weights there as well.
            clf.features = model_2.features
            clf.avgpool = model_2.avgpool

        torch.save(clf.state_dict(), out_file)
Example #6
0
def main(args):
    """Run ADDA for one source subject and one target person.

    Two copies of the checkpoint at ``args.MODEL_FILE`` are loaded: a frozen
    source network and a trainable target feature extractor. The source
    data comes from ``preprocess_train_single(1)``, the target data from
    ``preprocess_test(args.person)``. Per iteration a discriminator is
    trained to tell source features from target features, and the target
    extractor is then trained to fool it. After every epoch the adapted
    extractor is put into the classifier, evaluated with ``test`` and saved
    to ``trained_models/adda.pt`` when the accuracy improves. The per-epoch
    accuracies are finally dumped to ``<args.seed>/acc<args.person>.json``.

    Args:
        args: Namespace providing ``MODEL_FILE``, ``batch_size``,
            ``person``, ``epochs``, ``iterations``, ``k_disc``, ``k_clf``
            and ``seed``.
    """
    # Frozen source network; it is also the model that is evaluated and
    # saved.
    clf = Net().to(device)
    clf.load_state_dict(torch.load(args.MODEL_FILE))
    clf.eval()
    set_requires_grad(clf, requires_grad=False)
    source_model = clf.feature_extractor

    # Trainable target feature extractor, same initialisation.
    target_net = Net().to(device)
    target_net.load_state_dict(torch.load(args.MODEL_FILE))
    target_model = target_net.feature_extractor

    classifier = clf.classifier

    # The discriminator emits raw logits. BCEWithLogitsLoss applies the
    # sigmoid itself, so the network must not end in nn.Sigmoid (that would
    # apply the sigmoid twice and confine the effective probability to a
    # narrow range above 0.5).
    discriminator = nn.Sequential(nn.Linear(EXTRACTED_FEATURE_DIM, 64),
                                  nn.ReLU(), nn.BatchNorm1d(64),
                                  nn.Linear(64, 1)).to(device)

    batch_size = args.batch_size

    X_source, y_source = preprocess_train_single(1)
    source_dataset = torch.utils.data.TensorDataset(X_source, y_source)

    X_target, y_target = preprocess_test(args.person)
    target_dataset = torch.utils.data.TensorDataset(X_target, y_target)

    discriminator_optim = torch.optim.Adam(discriminator.parameters())
    target_optim = torch.optim.Adam(target_model.parameters(), lr=3e-6)
    criterion = nn.BCEWithLogitsLoss()
    criterion_class = nn.CrossEntropyLoss()

    # Accuracy of the unadapted model is the bar an epoch has to beat
    # before a checkpoint is written.
    best_tar_acc = test(args, clf)
    final_accs = []

    for epoch in range(1, args.epochs + 1):
        # Fresh shuffling loaders every epoch.
        source_loader = DataLoader(source_dataset,
                                   batch_size=batch_size,
                                   shuffle=True)
        target_loader = DataLoader(target_dataset,
                                   batch_size=batch_size,
                                   shuffle=True)
        # loop_iterable presumably restarts a loader once it is exhausted.
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))

        total_loss = 0
        adv_loss = 0
        total_accuracy = 0
        second_acc = 0
        total_class_loss = 0
        for _ in trange(args.iterations, leave=False):
            # Train discriminator
            set_requires_grad(target_model, requires_grad=False)
            set_requires_grad(discriminator, requires_grad=True)
            discriminator.train()
            for _ in range(args.k_disc):
                (source_x, source_y), (target_x, _) = next(batch_iterator)
                source_y = source_y.to(device).view(-1)
                source_x, target_x = source_x.to(device), target_x.to(device)

                source_features = source_model(source_x).view(
                    source_x.shape[0], -1)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)

                # Domain labels: 1 = source, 0 = target.
                discriminator_x = torch.cat([source_features, target_features])
                discriminator_y = torch.cat([
                    torch.ones(source_x.shape[0], device=device),
                    torch.zeros(target_x.shape[0], device=device)
                ])

                # squeeze(1) removes only the output dimension, so a batch
                # of one keeps its batch dimension.
                preds = discriminator(discriminator_x).squeeze(1)
                loss = criterion(preds, discriminator_y)

                discriminator_optim.zero_grad()
                loss.backward()
                discriminator_optim.step()

                total_loss += loss.item()
                # logit >= 0 is the same as probability >= 0.5
                total_accuracy += ((preds >= 0).long() == discriminator_y.
                                   long()).float().mean().item()

            # Train feature extractor
            set_requires_grad(target_model, requires_grad=True)
            set_requires_grad(discriminator, requires_grad=False)
            target_model.train()
            for _ in range(args.k_clf):
                _, (target_x, _) = next(batch_iterator)
                target_x = target_x.to(device)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)
                # source_x / source_y are the batch left over from the last
                # discriminator step above; the source batch drawn in this
                # loop is discarded.
                source_features = target_model(source_x).view(
                    source_x.shape[0], -1)
                source_pred = classifier(source_features)  # (batch_size, 4)

                # flipped labels
                discriminator_y = torch.ones(target_x.shape[0], device=device)

                preds = discriminator(target_features).squeeze(1)
                # Share of target samples the discriminator already
                # mistakes for source samples.
                second_acc += ((preds >= 0).long() == discriminator_y.long()
                               ).float().mean().item()

                loss_adv = criterion(preds, discriminator_y)
                adv_loss += loss_adv.item()
                # The classification loss on source data is only logged; it
                # is not part of the optimised objective.
                loss_class = criterion_class(source_pred, source_y)
                total_class_loss += loss_class.item()
                loss = loss_adv  #+ 0.001*loss_class

                target_optim.zero_grad()
                loss.backward()
                target_optim.step()

        mean_loss = total_loss / (args.iterations * args.k_disc)
        mean_adv_loss = adv_loss / (args.iterations * args.k_clf)
        total_class_loss = total_class_loss / (args.iterations * args.k_clf)
        dis_accuracy = total_accuracy / (args.iterations * args.k_disc)
        sec_acc = second_acc / (args.iterations * args.k_clf)

        # Create the full target model (adapted extractor + source
        # classifier head), evaluate it and keep the best checkpoint.
        clf.feature_extractor = target_model
        tar_accuracy = test(args, clf)
        final_accs.append(tar_accuracy)
        if tar_accuracy > best_tar_acc:
            best_tar_acc = tar_accuracy
            torch.save(clf.state_dict(), 'trained_models/adda.pt')

        tqdm.write(
            f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, adv_loss = {mean_adv_loss:.4f}, '
            f'discriminator_accuracy={dis_accuracy:.4f}, tar_accuary = {tar_accuracy:.4f}, best_accuracy = {best_tar_acc:.4f}, '
            f'sec_acc = {sec_acc:.4f}, total_class_loss: {total_class_loss:.4f}'
        )

    jd = {"test_acc": final_accs}
    with open(str(args.seed) + '/acc' + str(args.person) + '.json', 'w') as f:
        json.dump(jd, f)
def main(args):
    """Adapt a pre-trained classifier to a target domain with a Wasserstein critic.

    Loads ``Net`` weights from ``args.MODEL_FILE``, then alternates between
    (1) training a small MLP critic to estimate the Wasserstein distance
    between source and target features (with a gradient penalty) and
    (2) training the classifier on labelled source data while shrinking that
    distance. One log line is written per epoch to
    ``logs/{adapt_setting}_{name}.txt`` and one checkpoint per epoch to
    ``trained_models/{adapt_setting}_{name}_ep{epoch}.pt``.

    Args:
        args: Namespace providing ``MODEL_FILE``, ``batch_size``,
            ``adapt_setting`` (``'mnist2mnistm'``, ``'svhn2mnist'`` or
            ``'mnist2usps'``), ``name``, ``epochs``, ``iterations``,
            ``k_critic``, ``k_clf``, ``gamma`` and ``wd_clf``. The two
            list-file settings additionally read ``src_list``, ``src_root``,
            ``tar_list``, ``tar_root`` and ``img_type``.

    Raises:
        NotImplementedError: If ``args.adapt_setting`` is not one of the
            supported settings.
    """
    clf_model = Net().to(device)
    clf_model.load_state_dict(torch.load(args.MODEL_FILE))

    feature_extractor = clf_model.feature_extractor
    discriminator = clf_model.classifier

    # The critic consumes flattened features; its first layer fixes their
    # width at 320.
    critic = nn.Sequential(
        nn.Linear(320, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 1),
    ).to(device)

    def list_dataset(txt_file, root_dir, resize):
        """Build a list-file dataset, optionally resizing images to 28."""
        steps = [transforms.Resize(28)] if resize else []
        steps.append(transforms.ToTensor())
        return ImageClassdata(txt_file=txt_file,
                              root_dir=root_dir,
                              img_type=args.img_type,
                              transform=transforms.Compose(steps))

    if args.adapt_setting == 'mnist2mnistm':
        source_dataset = MNIST(config.DATA_DIR / 'mnist',
                               train=True,
                               download=True,
                               transform=Compose(
                                   [GrayscaleToRgb(),
                                    ToTensor()]))
        target_dataset = MNISTM(train=False)
    elif args.adapt_setting == 'svhn2mnist':
        # Only the source side is resized in this setting.
        source_dataset = list_dataset(args.src_list, args.src_root,
                                      resize=True)
        target_dataset = list_dataset(args.tar_list, args.tar_root,
                                      resize=False)
    elif args.adapt_setting == 'mnist2usps':
        # Only the target side is resized in this setting.
        source_dataset = list_dataset(args.src_list, args.src_root,
                                      resize=False)
        target_dataset = list_dataset(args.tar_list, args.tar_root,
                                      resize=True)
    else:
        raise NotImplementedError(
            f'Unsupported adapt_setting: {args.adapt_setting!r}')

    # Each step draws half a batch from each domain.
    half_batch = args.batch_size // 2
    loader_kwargs = dict(batch_size=half_batch,
                         shuffle=True,
                         num_workers=1,
                         pin_memory=True,
                         drop_last=True)
    source_loader = DataLoader(source_dataset, **loader_kwargs)
    target_loader = DataLoader(target_dataset, **loader_kwargs)

    critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-4)
    clf_optim = torch.optim.Adam(clf_model.parameters(), lr=1e-4)
    clf_criterion = nn.CrossEntropyLoss()

    # Both output directories must exist before the first epoch finishes.
    os.makedirs('logs', exist_ok=True)
    os.makedirs('trained_models', exist_ok=True)

    # The context manager closes the log even if training raises midway.
    with open(f'logs/{args.adapt_setting}_{args.name}.txt', 'w') as f:
        for epoch in range(1, args.epochs + 1):
            batch_iterator = zip(loop_iterable(source_loader),
                                 loop_iterable(target_loader))

            total_loss = 0
            target_label_accuracy = 0
            for _ in trange(args.iterations, leave=False):
                (source_x, source_y), (target_x, target_y) = next(
                    batch_iterator)

                # ---- Train critic (feature extractor frozen) ----
                set_requires_grad(feature_extractor, requires_grad=False)
                set_requires_grad(critic, requires_grad=True)

                source_x, target_x = source_x.to(device), target_x.to(device)
                source_y = source_y.to(device)

                # Features are constants for the critic updates, so compute
                # them once without tracking gradients.
                with torch.no_grad():
                    h_s = feature_extractor(source_x).view(
                        source_x.shape[0], -1)
                    h_t = feature_extractor(target_x).view(
                        target_x.shape[0], -1)

                for _ in range(args.k_critic):
                    gp = gradient_penalty(critic, h_s, h_t)

                    wasserstein_distance = (critic(h_s).mean() -
                                            critic(h_t).mean())
                    # The critic maximises the distance estimate, hence the
                    # negation; gamma weights the gradient penalty.
                    critic_cost = -wasserstein_distance + args.gamma * gp

                    critic_optim.zero_grad()
                    critic_cost.backward()
                    critic_optim.step()

                    total_loss += critic_cost.item()

                # ---- Train classifier (critic frozen) ----
                set_requires_grad(feature_extractor, requires_grad=True)
                set_requires_grad(critic, requires_grad=False)
                for _ in range(args.k_clf):
                    source_features = feature_extractor(source_x).view(
                        source_x.shape[0], -1)
                    target_features = feature_extractor(target_x).view(
                        target_x.shape[0], -1)

                    source_preds = discriminator(source_features)
                    clf_loss = clf_criterion(source_preds, source_y)
                    wasserstein_distance = (
                        critic(source_features).mean() -
                        critic(target_features).mean())

                    loss = clf_loss + args.wd_clf * wasserstein_distance
                    clf_optim.zero_grad()
                    loss.backward()
                    clf_optim.step()

                    # Monitoring only: score the updated classifier head on
                    # the target features without building a graph.
                    with torch.no_grad():
                        target_preds = discriminator(target_features.detach())
                    target_label_accuracy += (
                        target_preds.cpu().max(1)[1] ==
                        target_y).float().mean().item()

            mean_loss = total_loss / (args.iterations * args.k_critic)
            target_mean_accuracy = target_label_accuracy / (args.iterations *
                                                            args.k_clf)
            summary = (f'EPOCH {epoch:03d}: critic_loss={mean_loss:.4f}, '
                       f'target_accuracy={target_mean_accuracy:.4f}')
            tqdm.write(summary)
            # One line per epoch; flush so progress is visible during long
            # runs.
            f.write(summary + '\n')
            f.flush()
            torch.save(
                clf_model.state_dict(),
                f'trained_models/{args.adapt_setting}_{args.name}_ep{epoch}.pt'
            )