예제 #1
0
def main(config):
    """Run inference with a trained DANN model over every ``*.png`` in
    ``config.img_dir`` and write ``image_name,label`` CSV rows to
    ``config.save_path``.

    Args:
        config: namespace with ``ckp_path`` (checkpoint file produced by
            training, holding ``state['state_dict']``), ``img_dir``
            (directory of PNG test images) and ``save_path`` (output CSV).
    """
    # Scale each RGB channel from [0, 1] to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # BUG FIX: the original called .cuda() unconditionally, which crashes on
    # CPU-only machines even though `device` was computed above.
    model = DANN().to(device)

    # map_location lets a CUDA-trained checkpoint load on a CPU-only host.
    state = torch.load(config.ckp_path, map_location=device)
    model.load_state_dict(state['state_dict'])

    # Sort so the output rows have a deterministic order.
    filenames = sorted(glob.glob(os.path.join(config.img_dir, '*.png')))

    out_filename = config.save_path
    out_dir = os.path.dirname(config.save_path)
    if out_dir:  # dirname is '' for a bare filename; makedirs('') raises
        os.makedirs(out_dir, exist_ok=True)

    model.eval()
    with open(out_filename, 'w') as out_file:
        out_file.write('image_name,label\n')
        with torch.no_grad():
            for fn in filenames:
                data = Image.open(fn).convert('RGB')
                data = transform(data)
                data = torch.unsqueeze(data, 0).to(device)
                # Second argument is presumably the GRL alpha/lambda —
                # TODO confirm against DANN.forward.
                output, _ = model(data, 1)
                pred = output.max(1, keepdim=True)[1]  # index of the max log-probability
                # os.path.basename is portable; the original split on '/'.
                out_file.write(os.path.basename(fn) + ',' + str(pred.item()) + '\n')
예제 #2
0
        item_pr = 'Epoch: [{}/{}], classify_loss: {:.4f}, domain_loss_s: {:.4f}, domain_loss_t: {:.4f}, domain_loss: {:.4f},total_loss: {:.4f}'.format(
            epoch, args.nepoch, err_s_label.item(), err_s_domain.item(),
            err_t_domain.item(), err_domain.item(), err.item())
        print(item_pr)
        fp = open(args.result_path, 'a')
        fp.write(item_pr + '\n')

        # test
        acc_src = test(model, args.source, epoch)
        acc_tar = test(model, args.target, epoch)
        test_info = 'Source acc: {:.4f}, target acc: {:.4f}'.format(
            acc_src, acc_tar)
        fp.write(test_info + '\n')
        print(test_info)
        fp.close()

        if best_acc < acc_tar:
            best_acc = acc_tar
            if not os.path.exists(args.model_path):
                os.mkdirs(args.model_path)
            torch.save(model, '{}/mnist_mnistm.pth'.format(args.model_path))
    print('Test acc: {:.4f}'.format(best_acc))


if __name__ == '__main__':
    # Fixed seed so weight init and shuffling are reproducible.
    torch.random.manual_seed(10)
    # Statement order is significant: each call below may consume RNG state.
    src_loader, tar_loader = data_loader.load_data()
    dann_model = DANN(DEVICE).to(DEVICE)
    sgd = optim.SGD(dann_model.parameters(), lr=args.lr)
    train(dann_model, sgd, src_loader, tar_loader)
예제 #3
0
                                        IMAGE_DIM,
                                        DIR)
# Each epoch consumes source and target batches in lockstep, so it can run
# at most as many batches as the smaller of the two loaders provides.
max_batches = min(len(source_dataloader), len(target_dataloader))

# Instantiate Dicts
# NOTE(review): these are lists, not dicts — running per-batch metric logs.
i = 0
labels_loss = []
source_domain_loss = []
target_domain_loss = []
batch = []

# Instantiate Tensor Board Writer
writer = SummaryWriter()

# Instantiate Model
model = DANN(image_dim=28)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Load Optimizer and Loss Functions
optimizer = optim.Adam(model.parameters(), LR)

# instantiate 2 loss functions
# (separate instances for the label classifier and the domain discriminator)
class_loss_criterion = nn.CrossEntropyLoss()
domain_loss_criterion = nn.CrossEntropyLoss()

for epoch in range(EPOCHS):
    print(f'Epoch {epoch + 1:04d} / {EPOCHS:04d}', end='\n=================\n')
    # Fresh iterators every epoch; the loop body continues past this chunk.
    source_domain_iterator = iter(source_dataloader)
    target_domain_iterator = iter(target_dataloader)
예제 #4
0
 def build_model(self):
     """Create the DANN network on the configured device and an Adam
     optimizer over its parameters, storing both on ``self``."""
     net = DANN()
     self.model = net.to(self.device)
     adam_betas = [self.beta1, self.beta2]
     self.optimizer = torch.optim.Adam(
         self.model.parameters(), lr=self.lr, betas=adam_betas)
예제 #5
0

def fit_with_adaptation():
    """Placeholder for domain-adaptation training; not implemented yet."""


def eval(model, val_dataloader):
    """Switch *model* to evaluation mode; the metric computation over
    *val_dataloader* is not implemented yet.

    NOTE(review): this shadows the builtin ``eval`` — consider renaming.
    """
    model.eval()


if __name__ == '__main__':

    # Resize then center-crop down to the 224x224 input the backbone expects.
    preprocess = transforms.Compose([
        transforms.Resize(230),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # Normalizes tensor with mean and standard deviation of IMAGENET
    ])

    dataset = PACS(root=PACS_PATH,
                   transform=preprocess,
                   num_workers=4,
                   batch_size=BATCH_SIZE)
    net = DANN(pretrained=True, num_domains=2, num_classes=7)

    # Source-only baseline: train on the source domain, validate on target.
    fit_simple(net,
               epochs=5,
               train_dataloader=dataset[SOURCE_DOMAIN],
               val_dataloader=dataset[TARGET_DOMAIN])