def main():
    """Restore a trained classifier from a checkpoint and evaluate it.

    Builds the backbone selected by ``args.inception`` (Inception or ResNet),
    optionally passes it to ``fixed_pretrained``, then hands the model and an
    SGD optimizer to ``load_checkpoint`` for the file
    ``<cwd>/<args.model_dir>/<args.model_name>``.  Finally it runs either
    ``test_real`` (given a ``<model stem>_result.csv`` file name) or ``test``
    on the loader returned by ``load_data(args, train=False)``.

    The optimizer is only passed to ``load_checkpoint``; it is never stepped
    in this function.
    """
    args = parse_args()

    # Evaluation split (train=False).
    test_loader = load_data(args, train=False)

    checkpoint_path = join(os.getcwd(), args.model_dir, args.model_name)
    backbone = inception_345() if args.inception else resnet_345(args)
    model = backbone.cuda()

    if args.fixed_pretrained:
        fixed_pretrained(model)

    trainable = [p for _, p in model.named_parameters() if p.requires_grad]
    print('Number of parameters to update: {}'.format(len(trainable)))

    optimizer = optim.SGD(trainable, lr=args.lr, momentum=0.9)
    load_checkpoint(checkpoint_path, model, optimizer)

    # Both evaluation modes print the same header line.
    print('model:{}, target:{}'.format(args.model_name, args.target))
    if not args.real_test:
        test(model, test_loader, args)
        return

    result_csv = '{}_result.csv'.format(splitext(args.model_name)[0])
    test_real(model, test_loader, args, result_csv)
def test_model(args):
    """Evaluate a ResNet checkpoint on the target domain and return its accuracy.

    Loads ``dataset_public(train=False, domain=args.target)`` from
    ``<cwd>/<args.data_dir>``, restores weights with
    ``load_model(args.model_path, ...)`` and compares the arg-max class of
    ``model.predict`` with the labels yielded by the loader.

    Args:
        args: parsed command-line namespace (reads manual_seed, data_dir,
            img_size, target, batch_size, num_workers, model_path).

    Returns:
        float: fraction of correctly classified samples, or 0.0 when the
        loader yields no samples (previously a ZeroDivisionError).
    """
    torch.manual_seed(args.manual_seed)

    dataroot = join(os.getcwd(), args.data_dir)
    transform = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 on each of the three channels.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    target_dataset = dataset_public(
        root=dataroot,
        transform=transform,
        train=False,
        domain=args.target,
    )
    target_loader = torch.utils.data.DataLoader(
        dataset=target_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    model = resnet_345(args).cuda()
    model = load_model(args.model_path, model)

    preds = []
    labels = []
    model.eval()
    # The context manager closes the progress bar even if a batch raises.
    with tqdm(total=len(target_loader), ncols=100, leave=True) as test_pbar:
        for x_inputs, y_labels in target_loader:
            x_inputs = x_inputs.cuda()

            with torch.no_grad():
                output = model.predict(x_inputs)
            _, pred = torch.max(output, 1)
            preds.extend(pred.cpu().tolist())
            labels.extend(y_labels.tolist())

            test_pbar.update()

    preds = np.array(preds)
    labels = np.array(labels)
    if preds.size == 0:
        # Empty target split: report 0.0 instead of dividing by zero.
        print('valid acc: no samples in target loader')
        acc = 0.0
    else:
        acc = float(np.count_nonzero(preds == labels)) / float(preds.size)
    print('valid acc = {:4f}'.format(acc))

    return acc
def test_model_create_output(args):
    """Predict target-domain classes and write them to a CSV file.

    Reads ``dataset_public(train=False, domain=args.target)`` from
    ``args.data_dir`` and restores a ResNet via
    ``load_model(args.model_path, ...)``.  Writes
    ``<args.pred_dir>/output_<args.target>.csv`` with the header
    ``image_name,label`` and one ``<args.title>/<file name>,<class index>``
    row per sample, in loader order.

    NOTE(review): the loader's second item is treated as a file name, not a
    label. Presumably dataset_public yields file names for this split;
    confirm against its implementation.
    """
    torch.manual_seed(args.manual_seed)

    preprocess = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    loader = torch.utils.data.DataLoader(
        dataset=dataset_public(
            root=args.data_dir,
            transform=preprocess,
            train=False,
            domain=args.target,
        ),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    net = load_model(args.model_path, resnet_345(args).cuda())

    predictions = []
    names = []
    net.eval()
    progress = tqdm(total=len(loader), ncols=100, leave=True)
    for images, batch_names in loader:
        images = images.cuda()
        with torch.no_grad():
            scores = net.predict(images)
        top_class = torch.max(scores, 1)[1]
        predictions.extend(top_class.cpu().tolist())
        names.extend(batch_names)
        progress.update()
    progress.close()

    csv_path = join(args.pred_dir, 'output_{}.csv'.format(args.target))
    with open(csv_path, 'w') as f:
        f.write('image_name,label\n')
        for name, label in zip(names, predictions):
            f.write('{}/{},{}\n'.format(args.title, name, label))
def main():
    """Train a supervised classifier on the source domain(s) and checkpoint the best epoch.

    Trains ``inception_345`` (main loss plus a 0.4-weighted auxiliary loss) or
    ``resnet_345`` on ``dataset_public(train=True, domains=args.source)``.

    * With ``args.test``: after each epoch the model is scored on the target
      split via ``test``. A checkpoint is written under ``models_inception/``
      whenever test accuracy improves. Training stops after
      ``args.early_stopping`` consecutive epochs without improvement.
    * Without ``args.test``: a checkpoint is written under ``models/``
      whenever training accuracy improves. The patience counter is never
      advanced in this mode.

    Per-epoch loss and accuracy lines go to ``models_inception/<run>_err.txt``.

    Raises:
        ValueError: if neither ``args.SGD`` nor ``args.Adam`` is set
            (previously this failed later with an unbound ``optimizer``).
    """
    args = parse_args()

    torch.manual_seed(args.manual_seed)

    dataroot = join(os.getcwd(), args.data_dir)
    # The train split and the optional test split use the same preprocessing.
    transform = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    train_dataset = dataset_public(
        root=dataroot,
        transform=transform,
        train=True,
        domains=args.source,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    test_loader = None
    if args.test:
        test_dataset = dataset_public(
            root=dataroot,
            transform=transform,
            train=False,
            domains=args.target,
            real_test=args.real_test,
        )
        test_loader = torch.utils.data.DataLoader(
            dataset=test_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False,
        )

    # models
    cls = inception_345().cuda() if args.inception else resnet_345(args).cuda()

    if args.fixed_pretrained:
        fixed_pretrained(cls)

    params_to_update = [p for p in cls.parameters() if p.requires_grad]
    print('Number of parameters to update: {}'.format(len(params_to_update)))

    # optimizer
    if args.SGD:
        optimizer = optim.SGD(params_to_update, lr=args.lr, momentum=0.9)
    elif args.Adam:
        optimizer = optim.Adam(params_to_update, lr=args.lr)
    else:
        raise ValueError('No optimizer selected: set args.SGD or args.Adam')

    if args.resume:
        resume_path = join(os.getcwd(), args.model_dir, args.model_name)
        load_checkpoint(resume_path, cls, optimizer)

    # The error log and the checkpoint file share one run name.
    if len(args.source) > 1:
        run_name = 'src_combine_tar_{}'.format(args.target[0])
    else:
        run_name = 'src_{}'.format(args.source[0])
    err_log_path = join(os.getcwd(), 'models_inception', run_name + '_err.txt')
    # NOTE(review): checkpoints go to 'models_inception/' in test mode but to
    # 'models/' otherwise, while the log always goes to 'models_inception/'.
    # Preserved as-is; confirm this is intended.
    checkpoint_dir = 'models_inception' if args.test else 'models'
    checkpoint_path = join(os.getcwd(), checkpoint_dir, run_name + '.pth')

    criterion = nn.CrossEntropyLoss()
    best_acc = 0
    patience = 0
    # `with` closes the log even if training raises.
    with open(err_log_path, 'w') as err_log:
        for epoch in range(args.start_epoch, args.max_epochs + 1):
            cls.train()

            print('\nEpoch = {}'.format(epoch))
            err_log.write('Epoch = {}, '.format(epoch))

            losses, train_acc = AverageMeter(), AverageMeter()
            with tqdm(total=len(train_loader), ncols=100, leave=True) as train_pbar:
                for images, labels in train_loader:
                    images, labels = images.cuda(), labels.cuda()
                    # drop_last=False, so the final batch can be smaller than
                    # args.batch_size. Weight the meters by the real size.
                    batch_n = labels.size(0)

                    cls.zero_grad()

                    if args.inception:
                        # Inception returns (main logits, auxiliary logits).
                        class_output, class_aux_output = cls(images)
                        loss = (criterion(class_output, labels)
                                + 0.4 * criterion(class_aux_output, labels))
                    else:
                        class_output = cls(images)
                        loss = criterion(class_output, labels)
                    loss.backward()

                    pred = class_output.max(1, keepdim=True)[1]
                    correct = pred.eq(labels.view_as(pred)).sum().item()

                    optimizer.step()

                    losses.update(loss.item(), batch_n)
                    train_acc.update(correct, batch_n)
                    train_pbar.update()
                    train_pbar.set_postfix({
                        'loss': '{:.4f}'.format(losses.avg),
                        'acc': '{:.4f}'.format(train_acc.acc),
                    })

            if args.test:
                test_acc, test_loss = test(cls, test_loader, args)

                if test_acc.acc > best_acc:
                    best_acc = test_acc.acc
                    patience = 0
                    save_checkpoint(checkpoint_path, cls, optimizer)
                else:
                    patience += 1

                err_log.write('Loss: {:.4f}/{:.4f}, Accuracy: {:.4f}/{:.4f}\n'.format(
                    losses.avg, test_loss.avg, train_acc.acc, test_acc.acc))
            else:
                if train_acc.acc > best_acc:
                    best_acc = train_acc.acc
                    save_checkpoint(checkpoint_path, cls, optimizer)

                err_log.write('Loss: {:.4f}, Accuracy: {:.4f}\n'.format(
                    losses.avg, train_acc.acc))

            err_log.flush()

            if patience >= args.early_stopping:
                break

        err_log.write('Best test_acc: {:.4f}\n'.format(best_acc))
def main():
    """Adapt a target encoder adversarially against a domain discriminator (ADDA-style).

    ``src_CNN`` and ``tar_CNN`` are both initialised from
    ``<cwd>/<args.model_dir>/<args.model_name>``. ``src_CNN`` is passed to
    ``fixed_pretrained`` and is never given to an optimizer. Each step has
    two phases:

    1. Train ``D_cls`` to label source features 0 and target features 1.
    2. Train ``tar_CNN`` so that ``D_cls`` labels its features 0 (source).

    After each epoch ``tar_CNN`` is scored with ``test`` on the target loader.
    The best encoder/discriminator pair is saved to
    ``models/adda_<target>.pth``. Training stops after
    ``args.early_stopping`` epochs without improvement.

    Raises:
        ValueError: if neither ``args.SGD`` nor ``args.Adam`` is set.
    """
    args = parse_args()

    torch.manual_seed(args.manual_seed)

    dataroot = join(os.getcwd(), args.data_dir)
    transform = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # drop_last=True on both loaders keeps every batch at exactly
    # args.batch_size. The fixed-size domain label tensors below rely on this.
    src_dataset = dataset_public(
        root=dataroot,
        transform=transform,
        train=True,
        domains=args.source,
    )
    src_loader = torch.utils.data.DataLoader(
        dataset=src_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    tar_dataset = dataset_public(
        root=dataroot,
        transform=transform,
        train=True,
        domains=args.target,
    )
    tar_loader = torch.utils.data.DataLoader(
        dataset=tar_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )

    # models
    src_CNN = resnet_345(args).cuda()
    tar_CNN = resnet_345(args).cuda()
    D_cls = Domain_classifier().cuda()

    # load and fix pretrained models
    resume_path = join(os.getcwd(), args.model_dir, args.model_name)
    resume_pretrained(resume_path, src_CNN)
    resume_pretrained(resume_path, tar_CNN)

    fixed_pretrained(src_CNN, 'CNN')
    fixed_pretrained(src_CNN, 'fc')
    fixed_pretrained(tar_CNN, 'fc')

    # optimizer
    if args.SGD:
        optimizer_tar_CNN = optim.SGD(tar_CNN.parameters(), lr=args.lr, momentum=0.9)
        optimizer_D_cls = optim.SGD(D_cls.parameters(), lr=args.lr, momentum=0.9)
    elif args.Adam:
        optimizer_tar_CNN = optim.Adam(tar_CNN.parameters(), lr=args.lr)
        optimizer_D_cls = optim.Adam(D_cls.parameters(), lr=args.lr)
    else:
        raise ValueError('No optimizer selected: set args.SGD or args.Adam')

    # domain labels: 0 = source, 1 = target
    src_label = torch.full((args.batch_size, ), 0).long().cuda()
    tar_label = torch.full((args.batch_size, ), 1).long().cuda()
    domain_label = torch.cat((src_label, tar_label), 0)

    criterion = nn.CrossEntropyLoss()
    min_len = min(len(src_loader), len(tar_loader))
    best_acc = 0
    patience = 0
    err_log_path = join(os.getcwd(), 'models',
                        'adda_{}_err.txt'.format(args.target[0]))
    # `with` closes the log even if training raises.
    with open(err_log_path, 'w') as err_log:
        for epoch in range(args.start_epoch, args.max_epochs + 1):
            # src_CNN is never optimised, so keep it in eval mode (fixed
            # normalisation statistics, no dropout). tar_CNN is the encoder
            # being adapted, so it trains. The original had these modes swapped.
            src_CNN.eval()
            tar_CNN.train()
            D_cls.train()

            print('\nEpoch = {}'.format(epoch))
            err_log.write('Epoch = {}, '.format(epoch))

            losses, train_acc = AverageMeter(), AverageMeter()
            D_losses, D_acc = AverageMeter(), AverageMeter()
            with tqdm(total=min_len, ncols=100, leave=True) as train_pbar:
                for (src_imgs, _), (tar_imgs, tar_labels) in zip(src_loader, tar_loader):
                    src_imgs = src_imgs.cuda()
                    tar_imgs = tar_imgs.cuda()
                    tar_labels = tar_labels.cuda()

                    # ---- 1) discriminator step ----
                    D_cls.zero_grad()

                    # Neither encoder is updated in this phase, so build no graph
                    # for their features (the original detached them).
                    with torch.no_grad():
                        src_feature = src_CNN(src_imgs)
                        tar_feature = tar_CNN(tar_imgs)

                    pred_src_domain = D_cls(src_feature)
                    pred_tar_domain = D_cls(tar_feature)
                    domain_output = torch.squeeze(
                        torch.cat((pred_src_domain, pred_tar_domain), 0))

                    D_loss = criterion(domain_output, domain_label)
                    D_loss.backward()
                    optimizer_D_cls.step()

                    _, pred = torch.max(domain_output, 1)
                    D_correct = (pred == domain_label).float().mean().item()

                    # ---- 2) target-encoder step: make D predict "source" ----
                    tar_CNN.zero_grad()

                    tar_feature = tar_CNN(tar_imgs)
                    pred_tar_domain = torch.squeeze(D_cls(tar_feature))

                    loss = criterion(pred_tar_domain, src_label)
                    loss.backward()
                    optimizer_tar_CNN.step()

                    # ---- monitoring only: target classification accuracy ----
                    with torch.no_grad():
                        class_output = tar_CNN.predict(tar_imgs)
                    _, pred = torch.max(class_output, 1)
                    cls_correct = (pred == tar_labels).float().mean().item()

                    # update losses, accuracy and pbar
                    D_losses.update(D_loss.item(), args.batch_size * 2)
                    D_acc.update(D_correct, args.batch_size * 2)
                    losses.update(loss.item(), args.batch_size)
                    train_acc.update(cls_correct, args.batch_size)
                    train_pbar.update()

                    train_pbar.set_postfix({
                        'D_loss': '{:.3f}'.format(D_losses.avg),
                        'D_acc': '{:.3f}'.format(D_acc.avg),
                        'loss': '{:.3f}'.format(losses.avg),
                        'acc': '{:.3f}'.format(train_acc.avg),
                    })

            # NOTE(review): other scripts in this file call test(model, loader, args),
            # but this one passes args.batch_size. test() is defined elsewhere;
            # confirm its signature. Evaluation runs on the labelled target
            # training loader (drop_last=True), as in the original.
            test_acc, _ = test(tar_CNN, tar_loader, args.batch_size)
            if test_acc.avg > best_acc:
                best_acc = test_acc.avg
                patience = 0
                checkpoint_path = join(os.getcwd(), 'models',
                                       'adda_{}.pth'.format(args.target[0]))
                save_checkpoint(checkpoint_path, tar_CNN, D_cls, optimizer_tar_CNN,
                                optimizer_D_cls)
            else:
                patience += 1

            err_log.write('Loss: {:.4f}, Accuracy: {:.4f}\n'.format(
                losses.avg, test_acc.avg))
            err_log.flush()

            if patience >= args.early_stopping:
                print('Early stopping...')
                break

        err_log.write('Best test_acc: {:.4f}\n'.format(best_acc))
# Example no. 6
# 0
def main():
    """Adapt a target encoder by regressing its features onto a frozen source encoder.

    ``src_CNN``, ``tar_CNN`` and the classifier head ``cls`` are all restored
    from ``<cwd>/<args.model_dir>/<args.model_name>``. ``src_CNN`` and ``cls``
    are passed to ``fixed_pretrained``, and only ``tar_CNN`` is optimised.
    Each step minimises the MSE between the ``src_CNN`` features of a source
    batch and the ``tar_CNN`` features of a target batch. The accuracy of
    ``cls`` on the target features is tracked per epoch. The best epoch is
    saved to ``models/adda_<target>.pth``. Training stops after
    ``args.early_stopping`` epochs without improvement.

    Raises:
        ValueError: if neither ``args.SGD`` nor ``args.Adam`` is set.
    """
    args = parse_args()

    torch.manual_seed(args.manual_seed)

    dataroot = join(os.getcwd(), args.data_dir)
    transform = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    src_dataset = dataset_public(
        root=dataroot,
        transform=transform,
        train=True,
        domains=args.source,
    )
    src_loader = torch.utils.data.DataLoader(
        dataset=src_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    tar_dataset = dataset_public(
        root=dataroot,
        transform=transform,
        train=True,
        domains=args.target,
    )
    tar_loader = torch.utils.data.DataLoader(
        dataset=tar_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    # models
    src_CNN = resnet_345(args).cuda()
    tar_CNN = resnet_345(args).cuda()
    cls = classifier().cuda()

    checkpoint_path = join(os.getcwd(), args.model_dir, args.model_name)
    load_checkpoint(checkpoint_path, src_CNN)
    load_checkpoint(checkpoint_path, tar_CNN)
    load_checkpoint(checkpoint_path, cls)

    fixed_pretrained(src_CNN)
    fixed_pretrained(cls)

    # optimizer: only the target encoder is trained
    if args.SGD:
        optimizer = optim.SGD(tar_CNN.parameters(), lr=args.lr, momentum=0.9)
    elif args.Adam:
        optimizer = optim.Adam(tar_CNN.parameters(), lr=args.lr)
    else:
        raise ValueError('No optimizer selected: set args.SGD or args.Adam')

    min_len = min(len(src_loader), len(tar_loader))
    best_acc = 0
    stop_count = 0
    err_log_path = join(os.getcwd(), 'models',
                        'adda_{}_err.txt'.format(args.target[0]))
    # `with` closes the log even if training raises.
    with open(err_log_path, 'w') as err_log:
        for epoch in range(args.start_epoch, args.max_epochs + 1):
            src_CNN.eval()
            tar_CNN.train()
            cls.eval()

            print('\nEpoch = {}'.format(epoch))
            err_log.write('Epoch = {}, '.format(epoch))

            losses, train_acc = AverageMeter(), AverageMeter()
            with tqdm(total=min_len, ncols=100, leave=True) as train_pbar:
                for (src_imgs, _), (tar_imgs, tar_labels) in zip(src_loader, tar_loader):
                    # With drop_last=False the last batch of either loader can
                    # be partial, and F.mse_loss cannot broadcast feature batches
                    # of different sizes. Trim both batches to the common size.
                    batch_n = min(src_imgs.size(0), tar_imgs.size(0))
                    src_imgs = src_imgs[:batch_n].cuda()
                    tar_imgs = tar_imgs[:batch_n].cuda()
                    tar_labels = tar_labels[:batch_n].cuda()

                    tar_CNN.zero_grad()
                    # src_CNN is not optimised, so no graph is needed for its features.
                    with torch.no_grad():
                        src_feature = src_CNN(src_imgs)
                    tar_feature = tar_CNN(tar_imgs)

                    loss = F.mse_loss(src_feature, tar_feature, reduction='mean')
                    loss.backward()

                    # Monitoring only: classify the (pre-update) target features.
                    with torch.no_grad():
                        class_output = cls(tar_feature)
                    pred = class_output.max(1, keepdim=True)[1]
                    correct = pred.eq(tar_labels.view_as(pred)).sum().item()

                    optimizer.step()

                    # Weight the meters by the real (possibly partial) batch size.
                    losses.update(loss.item(), batch_n)
                    train_acc.update(correct, batch_n)
                    train_pbar.update()

                    train_pbar.set_postfix({
                        'loss': '{:.4f}'.format(losses.avg),
                        'acc': '{:.4f}'.format(train_acc.acc),
                    })

            if train_acc.acc > best_acc:
                best_acc = train_acc.acc
                stop_count = 0
                checkpoint_path = join(os.getcwd(), 'models',
                                       'adda_{}.pth'.format(args.target[0]))
                save_checkpoint(checkpoint_path, tar_CNN, cls, optimizer)
            else:
                stop_count += 1

            err_log.write('Loss: {:.4f}, Accuracy: {:.4f}\n'.format(
                losses.avg, train_acc.acc))
            err_log.flush()

            if stop_count >= args.early_stopping:
                break

        err_log.write('Best test_acc: {:.4f}\n'.format(best_acc))
def train_model(args):
    """Train a three-source M3SDA-style model and keep the best target-accuracy checkpoint.

    Builds one shuffled, ``drop_last=True`` loader per source domain plus one
    for the target domain. Each step feeds one batch from every loader to
    ``cls`` and minimises ``cls_loss + md_loss`` from ``M3SDA_Loss``. After
    each epoch, ``eval_model`` scores the model on the target loader, and a
    ``cls_loss,md_loss,acc`` row is appended to ``./logs/log_<target>.txt``.
    Whenever accuracy improves, the model is saved to
    ``./models/m3sda_<target>_<epoch>_<acc>.pth``. Training stops after
    ``args.early_stop`` epochs without improvement.

    Raises:
        ValueError: if ``args.source`` does not list exactly three domains
            (the model call and the loss are hard-wired to three sources), or
            if neither ``args.SGD`` nor ``args.Adam`` is set.
    """
    if len(args.source) != 3:
        raise ValueError(
            'train_model expects exactly 3 source domains, got {}'.format(
                len(args.source)))

    torch.manual_seed(args.manual_seed)

    dataroot = join(os.getcwd(), args.data_dir)
    transform = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    def _train_loader(domain):
        """Return a shuffled, full-batch training loader for one domain."""
        dataset = dataset_public(
            root=dataroot,
            transform=transform,
            train=True,
            domain=domain,
        )
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=True,
        )

    source_loader = [_train_loader(source) for source in args.source]
    target_loader = _train_loader(args.target)

    # models
    cls = resnet_345(args).cuda()

    resume_path = join(os.getcwd(), args.model_dir, args.model_name)
    resume_pretrained(resume_path, cls)

    if args.fixed_pretrained:
        fixed_pretrained(cls)

    params_to_update = [p for p in cls.parameters() if p.requires_grad]
    print('Number of parameters to update: {}'.format(len(params_to_update)))

    # optimizer
    if args.SGD:
        optimizer = optim.SGD(params_to_update, lr=args.lr, momentum=0.9)
    elif args.Adam:
        optimizer = optim.Adam(params_to_update, lr=args.lr)
    else:
        raise ValueError('No optimizer selected: set args.SGD or args.Adam')

    best_acc = 0.0
    patience = 0
    log_dir = './logs'
    os.makedirs(log_dir, exist_ok=True)
    # `with` closes the log; the original never closed it.
    with open(join(log_dir, 'log_{}.txt'.format(args.target)), 'w') as f:
        f.write('cls_loss,md_loss,acc\n')
        for epoch in range(args.start_epoch, args.max_epochs + 1):
            # NOTE(review): a fresh M3SDA_Loss is built every epoch, as in the
            # original. Whether it holds state is not visible here.
            criterion = M3SDA_Loss()
            cls.train()

            print('\nEpoch = {}'.format(epoch))

            cls_losses, md_losses, train_acc = AverageMeter(), AverageMeter(
            ), AverageMeter()
            len_train_loader = min(
                len(loader) for loader in source_loader + [target_loader])
            with tqdm(total=len_train_loader, ncols=100, leave=True) as train_pbar:
                for data_1, data_2, data_3, data_t in zip(*source_loader,
                                                          target_loader):
                    input_1, label_1 = data_1
                    input_2, label_2 = data_2
                    input_3, label_3 = data_3
                    input_t, _ = data_t  # target labels are not used for training

                    input_1, input_2, input_3 = (input_1.cuda(), input_2.cuda(),
                                                 input_3.cuda())
                    input_t = input_t.cuda()
                    label_1, label_2, label_3 = (label_1.cuda(), label_2.cuda(),
                                                 label_3.cuda())

                    cls.zero_grad()

                    e1, e2, e3, et, out_1, out_2, out_3 = cls(input_1, input_2,
                                                              input_3, input_t)
                    cls_loss, md_loss = criterion(e1, e2, e3, et, out_1, out_2,
                                                  out_3, label_1, label_2,
                                                  label_3)
                    loss = cls_loss + md_loss
                    loss.backward()
                    optimizer.step()

                    # Training accuracy over the three labelled source batches.
                    output = torch.cat((out_1, out_2, out_3), 0)
                    labels = torch.cat((label_1, label_2, label_3))
                    _, pred = torch.max(output, 1)
                    acc = (pred == labels).float().mean().item()

                    cls_losses.update(cls_loss.item(), input_1.shape[0])
                    md_losses.update(md_loss.item(), input_1.shape[0])
                    train_acc.update(acc, 1)

                    train_pbar.update()
                    train_pbar.set_postfix({
                        'cls_loss': '{:.4f}'.format(cls_losses.avg),
                        'md_loss': '{:.4f}'.format(md_losses.avg),
                        'acc': '{:.4f}'.format(train_acc.avg)
                    })

            # Evaluate every epoch (the original `epoch % 1 == 0` was always true).
            acc, _ = eval_model(cls, target_loader)
            f.write('{:4f},{:4f},{:4f}\n'.format(cls_losses.avg, md_losses.avg,
                                                 acc))
            f.flush()
            if acc > best_acc:
                best_acc = acc
                save_model(
                    join('./models',
                         'm3sda_{}_{}_{}.pth'.format(args.target, epoch, acc)),
                    cls)
                patience = 0
            else:
                patience += 1

            if patience >= args.early_stop:
                print('early stopping...')
                break
                dataset=dataset,
                batch_size=128,
                shuffle=True,
                num_workers=args.num_workers,
                pin_memory=True,
                drop_last=True,
            )

            source_dataset.append(dataset)
            source_loader.append(loader)

        target_dataset = dataset_public(
            root=dataroot,
            transform=transform,
            train=True,
            domain=args.target,
        )
        target_loader = DataLoader(
            dataset=target_dataset,
            batch_size=128,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=True,
        )

        model = resnet_345(args).cuda()
        model = load_model(args.model_path, model)

        draw_tsne(model, source_loader, target_loader, args)