示例#1
0
def main():
	"""Restore a saved classifier and run ``test`` on a 20-sample slice.

	The weights are read from ``./model.pth`` through ``load`` and applied to
	an ``ImageClassifier(28 * 28, 10)``. The data comes from
	``load_mnist(is_train=True, flatten=True)``; only the first 20 inputs and
	labels are handed to ``test`` (with ``to_be_shown=True``).
	"""
	# Use the GPU when one is available, otherwise fall back to the CPU.
	device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
	checkpoint_path = "./model.pth"

	# Fetch the data and move both tensors onto the selected device.
	inputs, targets = load_mnist(is_train=True, flatten=True)
	inputs = inputs.to(device)
	targets = targets.to(device)

	# Build the network on the same device, then restore its saved weights.
	model = ImageClassifier(28 * 28, 10).to(device)
	state_dict = load(checkpoint_path, device)
	model.load_state_dict(state_dict)

	sample_count = 20
	test(model, inputs[:sample_count], targets[:sample_count], to_be_shown=True)
def main(args: argparse.Namespace):
    """Run the complete training and evaluation pipeline.

    Steps, in order:

    1. Optionally seed ``random`` and ``torch`` and configure cuDNN.
    2. Build the source/target training loaders and the target validation
       loader from ``datasets.Office31``, filtered to the class subsets
       derived from ``args.n_share``, ``args.n_source_private`` and
       ``args.n_total``.
    3. Create the classifier, the domain discriminator and the ensemble,
       together with their optimizers and learning-rate schedulers.
    4. Call ``pretrain`` once, then for ``args.epochs`` epochs: train the
       ensemble on each of its data iterators, compute the source class
       weights, run one ``train`` epoch and validate.
    5. Restore the classifier weights that achieved the best validation
       accuracy (if any epoch improved on the initial 0.0) and report the
       accuracy on the test loader, which is the validation loader.

    Args:
        args: Parsed command-line options. This function reads ``seed``,
            ``n_share``, ``n_source_private``, ``n_total``, ``root``,
            ``source``, ``target``, ``batch_size``, ``workers``, ``lr``,
            ``momentum``, ``weight_decay`` and ``epochs``; the namespace is
            also forwarded unchanged to the training/evaluation helpers.

    Note:
        ``device`` is not a parameter; it is expected to exist at module
        level.
    """
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    # NOTE(review): benchmark mode stays enabled even when a seed is given.
    # cuDNN autotuning can select different kernels between runs, which may
    # weaken the reproducibility requested above -- confirm this is intended.
    cudnn.benchmark = True

    # Data loading code
    # Presumably the usual ImageNet channel statistics -- TODO confirm.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        ResizeImage(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ])
    val_transform = transforms.Compose([
        ResizeImage(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(), normalize
    ])

    # Split the label space into three consecutive index ranges:
    #   [0, n_share)                              -> shared by both domains
    #   [n_share, n_share + n_source_private)     -> source only
    #   [n_share + n_source_private, n_total)     -> target only
    n_share = args.n_share
    n_source_private = args.n_source_private
    n_total = args.n_total
    source_private_end = n_share + n_source_private
    common_classes = list(range(n_share))
    source_private_classes = list(range(n_share, source_private_end))
    target_private_classes = list(range(source_private_end, n_total))
    source_classes = common_classes + source_private_classes
    target_classes = common_classes + target_private_classes

    dataset = datasets.Office31
    train_source_dataset = dataset(root=args.root,
                                   data_list_file=args.source,
                                   filter_class=source_classes,
                                   transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    train_target_dataset = dataset(root=args.root,
                                   data_list_file=args.target,
                                   filter_class=target_classes,
                                   transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    # Validation reads the same target list as training, but with the
    # deterministic (center-crop) transform and without shuffling.
    val_dataset = dataset(root=args.root,
                          data_list_file=args.target,
                          filter_class=target_classes,
                          transform=val_transform)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)

    # There is no separate test split: testing reuses the validation loader.
    test_loader = val_loader

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)
    # One data iterator per ensemble training pass (five in the original
    # code); ``pretrain`` receives them as leading positional arguments.
    esem_iters = tuple(esem_dataloader(args, source_classes))

    # create model
    backbone = resnet50(pretrained=True)
    classifier = ImageClassifier(backbone,
                                 train_source_dataset.num_classes).to(device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim,
                                        hidden_size=1024).to(device)
    esem = Ensemble(classifier.features_dim,
                    train_source_dataset.num_classes).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters() +
                    domain_discri.get_parameters(),
                    args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay,
                    nesterov=True)
    lr_scheduler = StepwiseLR(optimizer,
                              init_lr=args.lr,
                              gamma=0.001,
                              decay_rate=0.75)

    optimizer_esem = SGD(esem.parameters(),
                         args.lr,
                         momentum=args.momentum,
                         weight_decay=args.weight_decay,
                         nesterov=True)
    # Every ensemble pass gets its own scheduler object, all of them
    # attached to the single shared ``optimizer_esem``.
    esem_lr_schedulers = [
        StepwiseLR(optimizer_esem,
                   init_lr=args.lr,
                   gamma=0.001,
                   decay_rate=0.75)
        for _ in esem_iters
    ]

    optimizer_pre = SGD(esem.get_parameters() + classifier.get_parameters(),
                        args.lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay,
                        nesterov=True)

    # define loss function
    domain_adv = DomainAdversarialLoss(domain_discri,
                                       reduction='none').to(device)

    pretrain(*esem_iters, classifier, esem, optimizer_pre, args)

    # start training
    best_acc1 = 0.
    # Stays None until some epoch beats ``best_acc1``; this avoids the
    # NameError the unguarded restore below would otherwise hit when
    # ``args.epochs`` is 0 or no epoch scores above 0.
    best_model = None
    for epoch in range(args.epochs):
        # train for one epoch

        # ``index`` is 1-based, matching the numbering of the iterators.
        passes = zip(esem_iters, esem_lr_schedulers)
        for index, (esem_iter, esem_lr_scheduler) in enumerate(passes,
                                                               start=1):
            train_esem(esem_iter,
                       classifier,
                       esem,
                       optimizer_esem,
                       esem_lr_scheduler,
                       epoch,
                       args,
                       index=index)

        source_class_weight = evaluate_source_common(val_loader, classifier,
                                                     esem, source_classes,
                                                     args)

        train(train_source_iter, train_target_iter, classifier, domain_adv,
              esem, optimizer, lr_scheduler, epoch, source_class_weight, args)

        # evaluate on validation set
        acc1 = validate(val_loader, classifier, esem, source_classes, args)

        # remember best acc@1 and snapshot the corresponding weights
        if acc1 > best_acc1:
            best_acc1 = acc1
            best_model = copy.deepcopy(classifier.state_dict())

    print("best_acc1 = {:3.3f}".format(best_acc1))

    # evaluate on test set
    if best_model is not None:
        classifier.load_state_dict(best_model)
    # Otherwise no epoch improved on the initial accuracy, so the classifier
    # is evaluated with the weights it currently holds.
    acc1 = validate(test_loader, classifier, esem, source_classes, args)
    print("test_acc1 = {:3.3f}".format(acc1))