Example #1
                                  # (snippet is truncated: it begins inside a preceding call)
                                  augment_func=Duplicate)

    if args.adv:
        model, teacher, disc, kwargs = experiment_setup(args)
        experiment = solver.AdversarialDropoutSolver(model, loader_plain,
                                                     **kwargs)
        experiment.optimize()

    if args.vada:
        model, teacher, disc, kwargs = experiment_setup(args)
        experiment = solver.VADASolver(model, disc, loader_plain, **kwargs)
        experiment.optimize()

    if args.dann:
        model, teacher, disc, kwargs = experiment_setup(args)
        experiment = solver.DANNSolver(model, disc, loader_plain, **kwargs)
        experiment.optimize()

    if args.assoc:
        model, teacher, disc, kwargs = experiment_setup(args)
        experiment = solver.AssociativeSolver(model, loader_plain, **kwargs)
        experiment.optimize()

    if args.coral:
        print(args.null)
        model, teacher, disc, kwargs = experiment_setup(args)
        experiment = solver.DeepCoralSolver(model,
                                            loader_plain,
                                            use_nullspace=args.null,
                                            **kwargs)
        experiment.optimize()
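
The dispatch above switches on boolean command-line flags. A minimal sketch of the parser it assumes follows; only the flag names are grounded in the snippet (args.adv, args.vada, ...), while help texts and everything else are guesses, so the real script may define them differently.

import argparse

# Hypothetical reconstruction of the flags the dispatch above switches on;
# only the names are taken from the snippet, the rest is assumed.
parser = argparse.ArgumentParser(description="Run a salad domain-adaptation solver")
parser.add_argument("--adv",   action="store_true", help="run AdversarialDropoutSolver")
parser.add_argument("--vada",  action="store_true", help="run VADASolver")
parser.add_argument("--dann",  action="store_true", help="run DANNSolver")
parser.add_argument("--assoc", action="store_true", help="run AssociativeSolver")
parser.add_argument("--coral", action="store_true", help="run DeepCoralSolver")
parser.add_argument("--null",  action="store_true", help="use the nullspace variant of Deep CORAL")
args = parser.parse_args()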
Example #2
File: train_dann.py  Project: zhmd/salad
    if osp.exists(args.checkpoint):
        print("Resume from checkpoint file at {}".format(args.checkpoint))
        model = torch.load(args.checkpoint)
    else:
        model = models.SVHNmodel()

    # Dataset
    data = datasets.load_dataset(path="data", train=True, img_size=32)

    train_loader = torch.utils.data.DataLoader(data[args.source],
                                               batch_size=args.sourcebatch,
                                               shuffle=True,
                                               num_workers=4)
    val_loader = torch.utils.data.DataLoader(data[args.target],
                                             batch_size=args.targetbatch,
                                             shuffle=True,
                                             num_workers=4)

    dataset = datasets.JointLoader(train_loader, val_loader)

    # Initialize the solver for this experiment
    experiment = solver.DANNSolver(model,
                                   dataset,
                                   n_epochs=args.epochs,
                                   savedir=args.log,
                                   dryrun=args.dryrun,
                                   learningrate=args.learningrate,
                                   gpu=args.gpu if not args.cpu else None)

    experiment.optimize()
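
The snippet above relies on an args namespace and on the models, datasets, and solver modules defined earlier in train_dann.py. A rough sketch of the imports and argument parser it assumes, reconstructed only from the attributes used in the snippet (module paths, defaults, and help texts are assumptions):

import argparse
import os.path as osp

import torch

# Assumed top-level salad modules; the real script may import these differently.
from salad import datasets, models, solver

parser = argparse.ArgumentParser(description="Train a DANN model with salad")
parser.add_argument("--checkpoint", default="", help="resume from this file if it exists")
parser.add_argument("--source", default="svhn", help="source-domain dataset key")
parser.add_argument("--target", default="mnist", help="target-domain dataset key")
parser.add_argument("--sourcebatch", type=int, default=128, help="source batch size")
parser.add_argument("--targetbatch", type=int, default=128, help="target batch size")
parser.add_argument("--epochs", type=int, default=100, help="number of training epochs")
parser.add_argument("--learningrate", type=float, default=3e-4, help="learning rate")
parser.add_argument("--log", default="./log", help="directory for logs and checkpoints")
parser.add_argument("--dryrun", action="store_true", help="run without writing logs")
parser.add_argument("--gpu", type=int, default=0, help="GPU index to train on")
parser.add_argument("--cpu", action="store_true", help="force training on the CPU")
args = parser.parse_args()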