def create_dataloaders(batch_size):
    """Build training and validation loaders over the MNIST training split.

    The dataset indices are shuffled once with NumPy and divided 80/20. Each
    loader then samples from its own index subset through a
    ``SubsetRandomSampler``.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_loader, val_loader)`` tuple. The training loader drops a
        trailing incomplete batch; the validation loader keeps it.
    """
    mnist = MNIST(config.DATA_DIR / 'mnist',
                  train=True,
                  download=True,
                  transform=Compose([GrayscaleToRgb(), ToTensor()]))

    # One permutation, one split point: the first 80% of the shuffled indices
    # go to training and the remainder to validation.
    permutation = np.random.permutation(len(mnist))
    split_at = int(0.8 * len(mnist))
    train_idx, val_idx = permutation[:split_at], permutation[split_at:]

    def make_loader(indices, drop_last):
        # Both loaders share the dataset and differ only in subset/drop_last.
        return DataLoader(mnist,
                          batch_size=batch_size,
                          drop_last=drop_last,
                          sampler=SubsetRandomSampler(indices),
                          num_workers=1,
                          pin_memory=True)

    train_loader = make_loader(train_idx, drop_last=True)
    val_loader = make_loader(val_idx, drop_last=False)
    return train_loader, val_loader
Beispiel #2
0
def main(args):
    """Fine-tune a pre-trained ``Net`` with a gradient-reversal discriminator.

    Weights are loaded from ``args.MODEL_FILE``. Every batch is the
    concatenation of a labelled MNIST (source) half and an ``MNISTM``
    (target) half. One Adam optimiser updates both the network and a small
    domain discriminator whose input passes through ``GradientReversal``.
    After each epoch the averaged metrics are printed and the model weights
    are written to ``trained_models/revgrad.pt``.

    Args:
        args: Namespace providing ``MODEL_FILE``, ``batch_size`` and
            ``epochs``.
    """
    model = Net().to(device)
    model.load_state_dict(torch.load(args.MODEL_FILE))
    feature_extractor, clf = model.feature_extractor, model.classifier

    # Domain head. It consumes the flattened extractor output, which the
    # first Linear layer fixes at 320 features.
    domain_layers = [
        GradientReversal(),
        nn.Linear(320, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 1),
    ]
    discriminator = nn.Sequential(*domain_layers).to(device)

    # Each domain contributes half of every combined batch.
    loader_kwargs = dict(batch_size=args.batch_size // 2,
                         shuffle=True,
                         num_workers=1,
                         pin_memory=True)
    source_dataset = MNIST(config.DATA_DIR / 'mnist',
                           train=True,
                           download=True,
                           transform=Compose([GrayscaleToRgb(), ToTensor()]))
    source_loader = DataLoader(source_dataset, **loader_kwargs)

    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset, **loader_kwargs)

    optim = torch.optim.Adam(
        [*discriminator.parameters(), *model.parameters()])

    for epoch in range(1, args.epochs + 1):
        # zip() stops with the shorter loader, hence min() for the bar total.
        batches = zip(source_loader, target_loader)
        n_batches = min(len(source_loader), len(target_loader))
        progress = tqdm(batches, leave=False, total=n_batches)

        domain_loss_sum = 0
        source_acc_sum = 0
        target_acc_sum = 0
        for (source_x, source_labels), (target_x, target_labels) in progress:
            n_source, n_target = source_x.shape[0], target_x.shape[0]

            x = torch.cat([source_x, target_x]).to(device)
            # Domain labels: 1 for source samples, 0 for target samples.
            domain_y = torch.cat(
                [torch.ones(n_source), torch.zeros(n_target)]).to(device)
            label_y = source_labels.to(device)

            features = feature_extractor(x).view(x.shape[0], -1)
            domain_preds = discriminator(features).squeeze()
            # Only the source half has labels for the classification loss.
            label_preds = clf(features[:n_source])

            domain_loss = F.binary_cross_entropy_with_logits(
                domain_preds, domain_y)
            label_loss = F.cross_entropy(label_preds, label_y)

            optim.zero_grad()
            (domain_loss + label_loss).backward()
            optim.step()

            domain_loss_sum += domain_loss.item()
            _, source_pred = label_preds.max(1)
            source_acc_sum += (source_pred == label_y).float().mean().item()

            # Monitoring only: target labels never reach the loss. They stay
            # on the CPU, so predictions are moved there for the comparison.
            _, target_pred = clf(features[n_source:]).cpu().max(1)
            target_acc_sum += (
                target_pred == target_labels).float().mean().item()

        mean_loss = domain_loss_sum / n_batches
        mean_accuracy = source_acc_sum / n_batches
        target_mean_accuracy = target_acc_sum / n_batches
        tqdm.write(
            f'EPOCH {epoch:03d}: domain_loss={mean_loss:.4f}, '
            f'source_accuracy={mean_accuracy:.4f}, target_accuracy={target_mean_accuracy:.4f}'
        )

        # Overwrites the same checkpoint after every epoch.
        torch.save(model.state_dict(), 'trained_models/revgrad.pt')
def main(args):
    """Adapt a pre-trained ``Net`` by alternating critic and classifier steps.

    Weights are loaded from ``args.MODEL_FILE``. Each iteration draws one
    MNIST (source) batch and one ``MNISTM`` (target) batch, then:

    1. trains the critic for ``args.k_critic`` steps on frozen features,
       maximising the critic-score gap between domains with a gradient
       penalty weighted by ``args.gamma``;
    2. trains the network for ``args.k_clf`` steps on the source
       classification loss plus ``args.wd_clf`` times that same gap.

    After every epoch the mean critic loss is printed and the weights are
    written to ``trained_models/wdgrl.pt``.

    Args:
        args: Namespace providing ``MODEL_FILE``, ``batch_size``, ``epochs``,
            ``iterations``, ``k_critic``, ``k_clf``, ``gamma`` and
            ``wd_clf``.
    """
    clf_model = Net().to(device)
    clf_model.load_state_dict(torch.load(args.MODEL_FILE))

    feature_extractor = clf_model.feature_extractor
    # NOTE: despite the name, this is the label-classifier head of ``Net``.
    discriminator = clf_model.classifier

    # Critic scoring flattened features; the first layer fixes the width at 320.
    critic = nn.Sequential(nn.Linear(320, 50), nn.ReLU(), nn.Linear(50, 20),
                           nn.ReLU(), nn.Linear(20, 1)).to(device)

    half_batch = args.batch_size // 2
    source_dataset = MNIST(config.DATA_DIR / 'mnist',
                           train=True,
                           download=True,
                           transform=Compose([GrayscaleToRgb(),
                                              ToTensor()]))
    source_loader = DataLoader(source_dataset,
                               batch_size=half_batch,
                               drop_last=True,
                               shuffle=True,
                               num_workers=0,
                               pin_memory=True)

    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset,
                               batch_size=half_batch,
                               drop_last=True,
                               shuffle=True,
                               num_workers=0,
                               pin_memory=True)

    critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-4)
    clf_optim = torch.optim.Adam(clf_model.parameters(), lr=1e-4)
    clf_criterion = nn.CrossEntropyLoss()

    for epoch in range(1, args.epochs + 1):
        # loop_iterable is defined elsewhere; presumably it restarts a loader
        # once exhausted so next() never raises StopIteration -- TODO confirm.
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))

        total_loss = 0
        for _ in trange(args.iterations, leave=False):
            # Target labels are not used by this method.
            (source_x, source_y), (target_x, _target_y) = next(batch_iterator)

            # Train critic
            set_requires_grad(feature_extractor, requires_grad=False)
            set_requires_grad(critic, requires_grad=True)

            source_x, target_x = source_x.to(device), target_x.to(device)
            source_y = source_y.to(device)

            # Features are held fixed for the critic steps; no_grad already
            # yields graph-free tensors, so no extra detaching is needed.
            with torch.no_grad():
                h_s = feature_extractor(source_x).view(source_x.shape[0], -1)
                h_t = feature_extractor(target_x).view(target_x.shape[0], -1)

            for _ in range(args.k_critic):
                gp = gradient_penalty(critic, h_s, h_t)

                critic_s = critic(h_s)
                critic_t = critic(h_t)
                wasserstein_distance = critic_s.mean() - critic_t.mean()

                # The critic maximises the distance, so it minimises its
                # negative, regularised by the gradient penalty.
                critic_cost = -wasserstein_distance + args.gamma * gp

                critic_optim.zero_grad()
                critic_cost.backward()
                critic_optim.step()

                total_loss += critic_cost.item()

            # Train classifier
            set_requires_grad(feature_extractor, requires_grad=True)
            set_requires_grad(critic, requires_grad=False)
            for _ in range(args.k_clf):
                source_features = feature_extractor(source_x).view(
                    source_x.shape[0], -1)
                target_features = feature_extractor(target_x).view(
                    target_x.shape[0], -1)

                source_preds = discriminator(source_features)
                clf_loss = clf_criterion(source_preds, source_y)
                wasserstein_distance = critic(source_features).mean() - critic(
                    target_features).mean()

                # The network minimises label loss plus the weighted distance.
                loss = clf_loss + args.wd_clf * wasserstein_distance
                clf_optim.zero_grad()
                loss.backward()
                clf_optim.step()

        mean_loss = total_loss / (args.iterations * args.k_critic)
        tqdm.write(f'EPOCH {epoch:03d}: critic_loss={mean_loss:.4f}')
        # Overwrites the same checkpoint after every epoch.
        torch.save(clf_model.state_dict(), 'trained_models/wdgrl.pt')
Beispiel #4
0
def main(args):
    """Adapt a target feature extractor against a frozen source network.

    Two ``Net`` instances are loaded from ``args.MODEL_FILE``. The first is
    put in eval mode and frozen; its feature extractor encodes MNIST (source)
    images and its classifier head is used to monitor target accuracy. The
    feature extractor of the second is trained on ``MNISTM`` (target) images
    to fool a domain discriminator, which is trained in alternation.

    After every epoch the metrics are printed, the adapted extractor is
    attached to the frozen network and the result is saved to
    ``trained_models/adda.pt``.

    Args:
        args: Namespace providing ``MODEL_FILE``, ``batch_size``, ``epochs``,
            ``iterations``, ``k_disc`` and ``k_clf``.
    """
    # Frozen source network: provides source features and the label head.
    clf = Net().to(device)
    clf.load_state_dict(torch.load(args.MODEL_FILE))
    clf.eval()
    set_requires_grad(clf, requires_grad=False)
    source_model = clf.feature_extractor

    # Only the feature extractor of the second network is kept and trained.
    target_net = Net().to(device)
    target_net.load_state_dict(torch.load(args.MODEL_FILE))
    target_model = target_net.feature_extractor
    del target_net
    target_clf = clf.classifier

    disc_layers = [
        nn.Linear(320, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 1),
    ]
    discriminator = nn.Sequential(*disc_layers).to(device)

    loader_kwargs = dict(batch_size=args.batch_size // 2,
                         shuffle=True,
                         num_workers=1,
                         pin_memory=True)
    source_dataset = MNIST(config.DATA_DIR / 'mnist',
                           train=True,
                           download=True,
                           transform=Compose([GrayscaleToRgb(), ToTensor()]))
    source_loader = DataLoader(source_dataset, **loader_kwargs)

    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset, **loader_kwargs)

    discriminator_optim = torch.optim.Adam(discriminator.parameters())
    target_optim = torch.optim.Adam(target_model.parameters())
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(1, args.epochs + 1):
        # loop_iterable is defined elsewhere; presumably it cycles a loader
        # indefinitely -- TODO confirm.
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))

        total_loss = 0
        total_accuracy = 0
        target_label_accuracy = 0
        for _ in trange(args.iterations, leave=False):
            # --- Discriminator phase: the target extractor is frozen. ---
            set_requires_grad(target_model, requires_grad=False)
            set_requires_grad(discriminator, requires_grad=True)
            for _ in range(args.k_disc):
                (source_x, _), (target_x, _) = next(batch_iterator)
                source_x, target_x = source_x.to(device), target_x.to(device)
                n_source, n_target = source_x.shape[0], target_x.shape[0]

                source_features = source_model(source_x).view(n_source, -1)
                target_features = target_model(target_x).view(n_target, -1)

                discriminator_x = torch.cat([source_features, target_features])
                # Domain labels: 1 for source features, 0 for target features.
                discriminator_y = torch.cat([
                    torch.ones(n_source, device=device),
                    torch.zeros(n_target, device=device),
                ])

                preds = discriminator(discriminator_x).squeeze()
                loss = criterion(preds, discriminator_y)

                discriminator_optim.zero_grad()
                loss.backward()
                discriminator_optim.step()

                total_loss += loss.item()
                # A logit above 0 counts as a "source" prediction.
                hits = (preds > 0).long() == discriminator_y.long()
                total_accuracy += hits.float().mean().item()

            # --- Target-extractor phase: the discriminator is frozen. ---
            set_requires_grad(target_model, requires_grad=True)
            set_requires_grad(discriminator, requires_grad=False)
            for _ in range(args.k_clf):
                # The source half of the pair is drawn but not used here.
                _, (target_x, target_labels) = next(batch_iterator)
                target_x = target_x.to(device)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)

                # Flipped labels: the extractor is rewarded when target
                # features are classified as source.
                discriminator_y = torch.ones(target_x.shape[0], device=device)

                preds = discriminator(target_features).squeeze()
                loss = criterion(preds, discriminator_y)

                target_optim.zero_grad()
                loss.backward()
                target_optim.step()

                # Monitoring only: the source label head scores the features
                # computed before this step; labels stay on the CPU.
                _, predicted = target_clf(target_features).cpu().max(1)
                target_label_accuracy += (
                    predicted == target_labels).float().mean().item()

        disc_steps = args.iterations * args.k_disc
        mean_loss = total_loss / disc_steps
        mean_accuracy = total_accuracy / disc_steps
        target_mean_accuracy = target_label_accuracy / (args.iterations *
                                                        args.k_clf)
        tqdm.write(
            f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, '
            f'discriminator_accuracy={mean_accuracy:.4f}, target_accuracy={target_mean_accuracy:.4f}'
        )

        # Create the full target model (adapted extractor + source label
        # head) and save it.
        clf.feature_extractor = target_model
        torch.save(clf.state_dict(), 'trained_models/adda.pt')
Beispiel #5
0
def main(args):
    """Gradient-reversal domain adaptation for a selectable dataset pair.

    Weights for ``Net`` are loaded from ``args.MODEL_FILE``. Source and
    target datasets are chosen by ``args.adapt_setting``. Every batch is the
    concatenation of a labelled source half and a target half. One Adam
    optimiser updates the network and a domain discriminator whose input
    passes through ``GradientReversal``.

    Per-epoch metrics are printed and appended to
    ``logs/{adapt_setting}_{name}.txt``; a checkpoint is written to
    ``trained_models/{adapt_setting}_{name}_ep{epoch}.pt`` after every epoch.

    Args:
        args: Namespace providing ``MODEL_FILE``, ``batch_size``, ``epochs``,
            ``adapt_setting`` and ``name``. For the ``svhn2mnist`` and
            ``mnist2usps`` settings it must also provide ``src_list``,
            ``src_root``, ``tar_list``, ``tar_root`` and ``img_type``.

    Raises:
        NotImplementedError: If ``args.adapt_setting`` is not one of
            ``'mnist2mnistm'``, ``'svhn2mnist'`` or ``'mnist2usps'``.
    """
    # TODO: add DTN model
    model = Net().to(device)
    model.load_state_dict(torch.load(args.MODEL_FILE))
    feature_extractor = model.feature_extractor
    clf = model.classifier

    # Domain head on flattened features; the first layer fixes the width at 320.
    discriminator = nn.Sequential(GradientReversal(), nn.Linear(320, 50),
                                  nn.ReLU(), nn.Linear(50, 20), nn.ReLU(),
                                  nn.Linear(20, 1)).to(device)

    half_batch = args.batch_size // 2
    if args.adapt_setting == 'mnist2mnistm':
        source_dataset = MNIST(config.DATA_DIR / 'mnist',
                               train=True,
                               download=True,
                               transform=Compose(
                                   [GrayscaleToRgb(),
                                    ToTensor()]))
        target_dataset = MNISTM(train=False)
    elif args.adapt_setting == 'svhn2mnist':
        # Only the source side is resized to 28 pixels in this setting.
        source_dataset = ImageClassdata(txt_file=args.src_list,
                                        root_dir=args.src_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.Resize(28),
                                            transforms.ToTensor(),
                                        ]))
        target_dataset = ImageClassdata(txt_file=args.tar_list,
                                        root_dir=args.tar_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.ToTensor(),
                                        ]))
    elif args.adapt_setting == 'mnist2usps':
        # Only the target side is resized to 28 pixels in this setting.
        source_dataset = ImageClassdata(txt_file=args.src_list,
                                        root_dir=args.src_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.ToTensor(),
                                        ]))
        target_dataset = ImageClassdata(txt_file=args.tar_list,
                                        root_dir=args.tar_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.Resize(28),
                                            transforms.ToTensor(),
                                        ]))
    else:
        raise NotImplementedError
    source_loader = DataLoader(source_dataset,
                               batch_size=half_batch,
                               shuffle=True,
                               num_workers=1,
                               pin_memory=True,
                               drop_last=True)
    target_loader = DataLoader(target_dataset,
                               batch_size=half_batch,
                               shuffle=True,
                               num_workers=1,
                               pin_memory=True,
                               drop_last=True)

    optim = torch.optim.Adam(
        list(discriminator.parameters()) + list(model.parameters()))

    # exist_ok avoids the check-then-create race. The checkpoint directory is
    # ensured up front so a missing folder cannot waste a whole first epoch.
    os.makedirs('logs', exist_ok=True)
    os.makedirs('trained_models', exist_ok=True)

    # The context manager closes (and flushes) the log even if training
    # raises part-way through.
    with open(f'logs/{args.adapt_setting}_{args.name}.txt', 'w') as log_file:
        for epoch in range(1, args.epochs + 1):
            # zip() stops with the shorter loader, hence min() for the total.
            batches = zip(source_loader, target_loader)
            n_batches = min(len(source_loader), len(target_loader))

            total_domain_loss = total_label_accuracy = 0
            target_label_accuracy = 0
            for (source_x,
                 source_labels), (target_x,
                                  target_labels) in tqdm(batches,
                                                         leave=False,
                                                         total=n_batches):
                x = torch.cat([source_x, target_x])
                x = x.to(device)
                # Domain labels: 1 for source samples, 0 for target samples.
                domain_y = torch.cat([
                    torch.ones(source_x.shape[0]),
                    torch.zeros(target_x.shape[0])
                ])
                domain_y = domain_y.to(device)
                label_y = source_labels.to(device)

                features = feature_extractor(x).view(x.shape[0], -1)
                domain_preds = discriminator(features).squeeze()
                # Only the source half has labels for the classification loss.
                label_preds = clf(features[:source_x.shape[0]])

                domain_loss = F.binary_cross_entropy_with_logits(
                    domain_preds, domain_y)
                label_loss = F.cross_entropy(label_preds, label_y)
                loss = domain_loss + label_loss

                optim.zero_grad()
                loss.backward()
                optim.step()

                total_domain_loss += domain_loss.item()
                total_label_accuracy += (
                    label_preds.max(1)[1] == label_y).float().mean().item()

                # Monitoring only: target labels never reach the loss and
                # stay on the CPU, so predictions are moved there to compare.
                target_label_preds = clf(features[source_x.shape[0]:])
                target_label_accuracy += (
                    target_label_preds.cpu().max(1)[1] ==
                    target_labels).float().mean().item()

            mean_loss = total_domain_loss / n_batches
            mean_accuracy = total_label_accuracy / n_batches
            target_mean_accuracy = target_label_accuracy / n_batches

            # Build the summary once so console and log file cannot drift.
            summary = (
                f'EPOCH {epoch:03d}: domain_loss={mean_loss:.4f}, '
                f'source_accuracy={mean_accuracy:.4f}, target_accuracy={target_mean_accuracy:.4f}'
            )
            tqdm.write(summary)
            log_file.write(summary + '\n')

            torch.save(
                model.state_dict(),
                f'trained_models/{args.adapt_setting}_{args.name}_ep{epoch}.pt'
            )
def main(args):
    """Wasserstein-critic domain adaptation for a selectable dataset pair.

    Weights for ``Net`` are loaded from ``args.MODEL_FILE``. Source and
    target datasets are chosen by ``args.adapt_setting``. Each iteration
    draws one source and one target batch, then:

    1. trains the critic for ``args.k_critic`` steps on frozen features,
       maximising the critic-score gap between domains with a gradient
       penalty weighted by ``args.gamma``;
    2. trains the network for ``args.k_clf`` steps on the source
       classification loss plus ``args.wd_clf`` times that gap, while
       tracking target accuracy for monitoring.

    Per-epoch metrics are printed and written, one line per epoch, to
    ``logs/{adapt_setting}_{name}.txt``; a checkpoint is written to
    ``trained_models/{adapt_setting}_{name}_ep{epoch}.pt`` after every epoch.

    Args:
        args: Namespace providing ``MODEL_FILE``, ``batch_size``, ``epochs``,
            ``iterations``, ``k_critic``, ``k_clf``, ``gamma``, ``wd_clf``,
            ``adapt_setting`` and ``name``. For the ``svhn2mnist`` and
            ``mnist2usps`` settings it must also provide ``src_list``,
            ``src_root``, ``tar_list``, ``tar_root`` and ``img_type``.

    Raises:
        NotImplementedError: If ``args.adapt_setting`` is not one of
            ``'mnist2mnistm'``, ``'svhn2mnist'`` or ``'mnist2usps'``.
    """
    clf_model = Net().to(device)
    clf_model.load_state_dict(torch.load(args.MODEL_FILE))

    feature_extractor = clf_model.feature_extractor
    # NOTE: despite the name, this is the label-classifier head of ``Net``.
    discriminator = clf_model.classifier

    # Critic scoring flattened features; the first layer fixes the width at 320.
    critic = nn.Sequential(nn.Linear(320, 50), nn.ReLU(), nn.Linear(50, 20),
                           nn.ReLU(), nn.Linear(20, 1)).to(device)

    half_batch = args.batch_size // 2
    if args.adapt_setting == 'mnist2mnistm':
        source_dataset = MNIST(config.DATA_DIR / 'mnist',
                               train=True,
                               download=True,
                               transform=Compose(
                                   [GrayscaleToRgb(),
                                    ToTensor()]))
        target_dataset = MNISTM(train=False)
    elif args.adapt_setting == 'svhn2mnist':
        # Only the source side is resized to 28 pixels in this setting.
        source_dataset = ImageClassdata(txt_file=args.src_list,
                                        root_dir=args.src_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.Resize(28),
                                            transforms.ToTensor(),
                                        ]))
        target_dataset = ImageClassdata(txt_file=args.tar_list,
                                        root_dir=args.tar_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.ToTensor(),
                                        ]))
    elif args.adapt_setting == 'mnist2usps':
        # Only the target side is resized to 28 pixels in this setting.
        source_dataset = ImageClassdata(txt_file=args.src_list,
                                        root_dir=args.src_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.ToTensor(),
                                        ]))
        target_dataset = ImageClassdata(txt_file=args.tar_list,
                                        root_dir=args.tar_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.Resize(28),
                                            transforms.ToTensor(),
                                        ]))
    else:
        raise NotImplementedError
    source_loader = DataLoader(source_dataset,
                               batch_size=half_batch,
                               shuffle=True,
                               num_workers=1,
                               pin_memory=True,
                               drop_last=True)
    target_loader = DataLoader(target_dataset,
                               batch_size=half_batch,
                               shuffle=True,
                               num_workers=1,
                               pin_memory=True,
                               drop_last=True)

    critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-4)
    clf_optim = torch.optim.Adam(clf_model.parameters(), lr=1e-4)
    clf_criterion = nn.CrossEntropyLoss()

    # exist_ok avoids the check-then-create race. The checkpoint directory is
    # ensured up front so a missing folder cannot waste a whole first epoch.
    os.makedirs('logs', exist_ok=True)
    os.makedirs('trained_models', exist_ok=True)

    # The context manager closes (and flushes) the log even if training
    # raises part-way through.
    with open(f'logs/{args.adapt_setting}_{args.name}.txt', 'w') as log_file:
        for epoch in range(1, args.epochs + 1):
            # loop_iterable is defined elsewhere; presumably it cycles a
            # loader indefinitely -- TODO confirm.
            batch_iterator = zip(loop_iterable(source_loader),
                                 loop_iterable(target_loader))

            total_loss = 0
            target_label_accuracy = 0
            for _ in trange(args.iterations, leave=False):
                (source_x, source_y), (target_x,
                                       target_y) = next(batch_iterator)
                # Train critic
                set_requires_grad(feature_extractor, requires_grad=False)
                set_requires_grad(critic, requires_grad=True)

                source_x, target_x = source_x.to(device), target_x.to(device)
                source_y = source_y.to(device)

                # Features are held fixed for the critic steps; no_grad
                # already yields graph-free tensors.
                with torch.no_grad():
                    h_s = feature_extractor(source_x).view(
                        source_x.shape[0], -1)
                    h_t = feature_extractor(target_x).view(
                        target_x.shape[0], -1)
                for _ in range(args.k_critic):
                    gp = gradient_penalty(critic, h_s, h_t)

                    critic_s = critic(h_s)
                    critic_t = critic(h_t)
                    wasserstein_distance = critic_s.mean() - critic_t.mean()

                    # The critic maximises the distance, so it minimises its
                    # negative, regularised by the gradient penalty.
                    critic_cost = -wasserstein_distance + args.gamma * gp

                    critic_optim.zero_grad()
                    critic_cost.backward()
                    critic_optim.step()

                    total_loss += critic_cost.item()

                # Train classifier
                set_requires_grad(feature_extractor, requires_grad=True)
                set_requires_grad(critic, requires_grad=False)
                for _ in range(args.k_clf):
                    source_features = feature_extractor(source_x).view(
                        source_x.shape[0], -1)
                    target_features = feature_extractor(target_x).view(
                        target_x.shape[0], -1)

                    source_preds = discriminator(source_features)
                    clf_loss = clf_criterion(source_preds, source_y)
                    wasserstein_distance = critic(
                        source_features).mean() - critic(
                            target_features).mean()

                    loss = clf_loss + args.wd_clf * wasserstein_distance
                    clf_optim.zero_grad()
                    loss.backward()
                    clf_optim.step()

                    # Monitoring only: target labels never reach the loss and
                    # stay on the CPU, so predictions are moved there.
                    target_preds = discriminator(target_features)
                    target_label_accuracy += (
                        target_preds.cpu().max(1)[1] ==
                        target_y).float().mean().item()

            mean_loss = total_loss / (args.iterations * args.k_critic)
            target_mean_accuracy = target_label_accuracy / (args.iterations *
                                                            args.k_clf)

            # Build the summary once so console and log file cannot drift;
            # the newline keeps one epoch per line in the log.
            summary = (
                f'EPOCH {epoch:03d}: critic_loss={mean_loss:.4f}, target_accuracy={target_mean_accuracy:.4f}'
            )
            tqdm.write(summary)
            log_file.write(summary + '\n')

            torch.save(
                clf_model.state_dict(),
                f'trained_models/{args.adapt_setting}_{args.name}_ep{epoch}.pt'
            )
def main(args):
    """Adversarially adapt a pre-trained ``BayesNet`` from MNIST to MNISTM.

    Weights are loaded from ``args.MODEL_FILE``. For every combined batch the
    model's two outputs are passed through ``reparameterize`` to obtain
    logits. These serve both as class predictions and as the input of a
    domain discriminator. The discriminator and the model are updated in turn
    with separate Adam optimisers. After every epoch the averaged metrics are
    printed and the model weights are written to ``trained_models/auda.pt``.

    Args:
        args: Namespace providing ``MODEL_FILE``, ``batch_size`` and
            ``epochs``.
    """
    model = BayesNet().to(device)
    model.load_state_dict(torch.load(args.MODEL_FILE))

    # The discriminator works on the 10-dimensional logits.
    disc_layers = [
        nn.Linear(10, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 1),
    ]
    discriminator = nn.Sequential(*disc_layers).to(device)

    # Each domain contributes half of every combined batch.
    loader_kwargs = dict(batch_size=args.batch_size // 2,
                         shuffle=True,
                         num_workers=1,
                         pin_memory=True)
    source_dataset = MNIST(config.DATA_DIR / 'mnist',
                           train=True,
                           download=True,
                           transform=Compose([GrayscaleToRgb(), ToTensor()]))
    source_loader = DataLoader(source_dataset, **loader_kwargs)

    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset, **loader_kwargs)

    optim_D = torch.optim.Adam(discriminator.parameters())
    optim_G = torch.optim.Adam(model.parameters())

    for epoch in range(1, args.epochs + 1):
        # zip() stops with the shorter loader, hence min() for the bar total.
        batches = zip(source_loader, target_loader)
        n_batches = min(len(source_loader), len(target_loader))
        progress = tqdm(batches, leave=False, total=n_batches)

        domain_loss_sum = 0
        source_acc_sum = 0
        target_acc_sum = 0
        for (source_x, source_labels), (target_x, target_labels) in progress:
            n_source = source_x.shape[0]

            x = torch.cat([source_x, target_x]).to(device)
            # Domain labels: 1 for source samples, 0 for target samples.
            domain_y = torch.cat(
                [torch.ones(n_source),
                 torch.zeros(target_x.shape[0])]).to(device)
            label_y = source_labels.to(device)

            # forward
            y_preds, s_preds = model(x)
            logits = reparameterize(y_preds, s_preds)

            # Discriminator step: logits are detached so that only the
            # discriminator receives gradients here.
            disc_preds = discriminator(logits.detach()).squeeze()
            disc_loss = F.binary_cross_entropy_with_logits(
                disc_preds, domain_y)
            optim_D.zero_grad()
            disc_loss.backward()
            optim_D.step()

            # Model step: source classification loss, plus a domain loss that
            # pushes the source logits towards the "target" label (0).
            label_preds = logits[:n_source]
            label_loss = F.cross_entropy(label_preds, label_y)
            source_domain_preds = discriminator(logits[:n_source]).squeeze()
            domain_loss = F.binary_cross_entropy_with_logits(
                source_domain_preds,
                torch.zeros(n_source).to(device))
            optim_G.zero_grad()
            (label_loss + domain_loss).backward()
            optim_G.step()

            # The reported domain loss is the one from the model step.
            domain_loss_sum += domain_loss.item()
            _, source_pred = label_preds.max(1)
            source_acc_sum += (source_pred == label_y).float().mean().item()

            # Monitoring only: target labels stay on the CPU, so the target
            # logits are moved there for the comparison.
            _, target_pred = logits[n_source:].cpu().max(1)
            target_acc_sum += (
                target_pred == target_labels).float().mean().item()

        mean_loss = domain_loss_sum / n_batches
        mean_accuracy = source_acc_sum / n_batches
        target_mean_accuracy = target_acc_sum / n_batches
        tqdm.write(
            f'EPOCH {epoch:03d}: domain_loss={mean_loss:.4f}, '
            f'source_accuracy={mean_accuracy:.4f}, target_accuracy={target_mean_accuracy:.4f}'
        )

        # Overwrites the same checkpoint after every epoch.
        torch.save(model.state_dict(), 'trained_models/auda.pt')