예제 #1
0
def main() -> None:
    """Train a classifier on a digit/object dataset and report test accuracy.

    Parses CLI options, builds train/val/test loaders via ``get_dataset``,
    constructs the requested architecture, trains with per-dataset SGD and
    LR-schedule settings, early-stops on validation accuracy, then reloads
    the best checkpoint and evaluates on the test set.  Depends on names
    defined elsewhere in this project: ``get_dataset``, ``network``,
    ``train``, ``test``, ``EarlyStopping``, ``SummaryWriter``, ``pkl``.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description='PyTorch classification example')
    parser.add_argument('--dataset',
                        type=str,
                        help='dataset',
                        choices=[
                            'mnist',
                            'usps',
                            'svhn',
                            'syn_digits',
                            'imagenet32x32',
                            'cifar10',
                            'stl10',
                        ])
    parser.add_argument('--arch', type=str, help='network architecture')
    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--val_ratio',
                        type=float,
                        default=0.0,
                        help='sampling ratio of validation data')
    parser.add_argument('--train_ratio',
                        type=float,
                        default=1.0,
                        help='sampling ratio of training data')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--wd',
                        type=float,
                        default=1e-6,
                        help='weight_decay (default: 1e-6)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--output_path',
                        type=str,
                        help='path to save ckpt and log. ')
    parser.add_argument('--resume',
                        type=str,
                        help='resume training from ckpt path')
    parser.add_argument('--ckpt_file', type=str, help='init model from ckpt. ')
    parser.add_argument(
        '--exclude_vars',
        type=str,
        help=
        'prefix of variables not restored form ckpt, seperated with commas; valid if ckpt_file is not None'
    )
    parser.add_argument('--imagenet_pretrain',
                        action='store_true',
                        help='use pretrained imagenet model')
    args = parser.parse_args()
    use_cuda = torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    # NOTE(review): --output_path has no default; if omitted,
    # os.path.join(None, "model.pt") below would raise TypeError —
    # presumably the script is always launched with it. TODO confirm.
    if args.output_path is not None and not os.path.exists(args.output_path):
        os.makedirs(args.output_path)

    # Metrics/summary logger rooted at output_path.
    writer = SummaryWriter(args.output_path)

    use_normalize = True
    # Resolve class count (and a dataset-specific batch size) per dataset.
    if args.dataset == 'imagenet32x32':
        n_classes = 1000
        args.batch_size = 256
    elif args.dataset in ["cifar10", "stl10"]:
        # 9 classes for cifar10/stl10 — presumably the classes the two
        # datasets share in this adaptation setup; confirm against get_dataset.
        n_classes = 9
    elif args.dataset in ["usps", "mnist", "svhn", 'syn_digits']:
        n_classes = 10
    else:
        raise ValueError('invalid dataset option: {}'.format(args.dataset))

    # Worker/pin-memory options only when CUDA is present.
    kwargs = {'num_workers': 2, 'pin_memory': True} if use_cuda else {}
    # Sanity-check sampling ratios: val in [0, 1), train in (0, 1].
    assert (args.val_ratio >= 0. and args.val_ratio < 1.)
    assert (args.train_ratio > 0. and args.train_ratio <= 1.)
    train_ds = get_dataset(args.dataset,
                           'train',
                           use_normalize=use_normalize,
                           test_size=args.val_ratio,
                           train_size=args.train_ratio)
    train_loader = torch.utils.data.DataLoader(train_ds,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(get_dataset(
        args.dataset,
        'test',
        use_normalize=use_normalize,
        test_size=args.val_ratio,
        train_size=args.train_ratio),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              **kwargs)
    if args.val_ratio == 0.0:
        # No held-out validation split: validate directly on the test set.
        val_loader = test_loader
    else:
        val_ds = get_dataset(args.dataset,
                             'val',
                             use_normalize=use_normalize,
                             test_size=args.val_ratio,
                             train_size=args.train_ratio)
        val_loader = torch.utils.data.DataLoader(val_ds,
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 **kwargs)

    # Build the backbone on the target device.
    if args.arch == "DTN":
        model = network.DTN().to(device)
    elif args.arch == 'wrn':
        model = network.WideResNet(depth=28,
                                   num_classes=n_classes,
                                   widen_factor=10,
                                   dropRate=0.0).to(device)
    else:
        raise ValueError('invalid network architecture {}'.format(args.arch))

    # Warm-start from a checkpoint, optionally skipping variables whose
    # prefixes are listed (comma-separated) in --exclude_vars.
    if args.ckpt_file is not None:
        print('initialize model parameters from {}'.format(args.ckpt_file))
        model.restore_from_ckpt(torch.load(args.ckpt_file, map_location='cpu'),
                                exclude_vars=args.exclude_vars.split(',')
                                if args.exclude_vars is not None else [])
        print('accuracy on test set before fine-tuning')
        test(args, model, device, test_loader)

    # Resuming overwrites all weights with the saved state dict.
    if args.resume is not None:
        assert (os.path.isfile(args.resume))
        print('resume training from {}'.format(args.resume))
        model.load_state_dict(torch.load(args.resume))

    if use_cuda:
        # model = torch.nn.DataParallel(model)
        cudnn.benchmark = True

    # Per-dataset optimizer and LR-schedule hyper-parameters.
    if args.dataset.startswith("cifar") or args.dataset in ['stl10']:
        # NOTE(review): lr_decay_step/lr_decay_rate are set but unused in
        # this branch — MultiStepLR below defines the actual schedule.
        lr_decay_step = 100
        lr_decay_rate = 0.1
        PATIENCE = 100
        optimizer = optim.SGD(model.get_parameters(args.lr),
                              momentum=args.momentum,
                              weight_decay=args.wd)
        scheduler = MultiStepLR(optimizer, milestones=[150, 250], gamma=0.1)
    elif args.dataset in ["mnist", "usps", "svhn", "syn_digits"]:
        lr_decay_step = 50
        lr_decay_rate = 0.5
        if args.dataset == 'svhn':
            PATIENCE = 10
        else:
            PATIENCE = 50
        # NOTE(review): momentum is hard-coded to 0.5 here, ignoring
        # --momentum — confirm this is intentional for the digit datasets.
        optimizer = optim.SGD(model.get_parameters(args.lr),
                              momentum=0.5,
                              weight_decay=args.wd)
        scheduler = StepLR(optimizer,
                           step_size=lr_decay_step,
                           gamma=lr_decay_rate)
    elif args.dataset == 'imagenet32x32':
        PATIENCE = 10
        lr_decay_step = 10
        lr_decay_rate = 0.2
        optimizer = torch.optim.SGD(model.get_parameters(args.lr),
                                    momentum=0.9,
                                    weight_decay=5e-4,
                                    nesterov=True)
        scheduler = StepLR(optimizer,
                           step_size=lr_decay_step,
                           gamma=lr_decay_rate)
    else:
        raise ValueError("invalid dataset option: {}".format(args.dataset))

    # Stop when validation accuracy fails to improve for PATIENCE epochs.
    early_stop_engine = EarlyStopping(PATIENCE)

    print("args:{}".format(args))

    # start training.
    best_accuracy = 0.
    save_path = os.path.join(args.output_path, "model.pt")
    time_stats = []
    for epoch in range(1, args.epochs + 1):
        start_time = time.time()
        train(args, model, device, train_loader, optimizer, epoch, writer)
        training_time = time.time() - start_time
        print('epoch: {} training time: {:.2f}'.format(epoch, training_time))
        time_stats.append(training_time)

        val_accuracy = test(args, model, device, val_loader)
        scheduler.step()

        writer.add_scalar("val_accuracy", val_accuracy, epoch)
        # Checkpoint whenever validation accuracy ties or improves the best.
        if val_accuracy >= best_accuracy:
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), save_path)

        # Periodic test-set evaluation, for monitoring only.
        if epoch % 20 == 0:
            print('accuracy on test set at epoch {}'.format(epoch))
            test(args, model, device, test_loader)

        if early_stop_engine.is_stop_training(val_accuracy):
            print(
                "no improvement after {}, stop training at epoch {}\n".format(
                    PATIENCE, epoch))
            break

    # print('finish training {} epochs'.format(args.epochs))
    mean_training_time = np.mean(np.array(time_stats))
    print('Average training_time: {}'.format(mean_training_time))
    # Reload the best-validation checkpoint and report final test accuracy.
    print('load ckpt with best validation accuracy from {}'.format(save_path))
    model.load_state_dict(torch.load(save_path, map_location='cpu'))
    test_accuracy = test(args, model, device, test_loader)

    writer.add_scalar("test_accuracy", test_accuracy, args.epochs)
    # NOTE(review): the 'train' key actually stores the best *validation*
    # accuracy (best_accuracy tracks val_accuracy above).
    with open(os.path.join(args.output_path, 'accuracy.pkl'),
              'wb') as pkl_file:
        pkl.dump(
            {
                'train': best_accuracy,
                'test': test_accuracy,
                'training_time': mean_training_time
            }, pkl_file)
예제 #2
0
def main() -> None:
    """Run CDAN/CDAN-E/DANN domain adaptation from SVHN to MNIST.

    Parses CLI options, builds source (SVHN), target (MNIST train) and test
    (MNIST test) loaders from image-list files, trains a DTN feature
    extractor adversarially against ``ad_net``, and evaluates on the test
    split every epoch.  Depends on names defined elsewhere in this project:
    ``network``, ``ImageList``, ``transforms``, ``optim``, ``train``,
    ``test``.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='CDAN SVHN MNIST')
    # BUG FIX: 'method' is a positional argument, so argparse ignored its
    # default= and made it mandatory; nargs='?' makes the documented default
    # effective while staying compatible with callers that pass it.
    parser.add_argument('method',
                        type=str,
                        nargs='?',
                        default='CDAN-E',
                        choices=['CDAN', 'CDAN-E', 'DANN'])
    parser.add_argument('--task', default='USPS2MNIST', help='task to perform')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.03, metavar='LR')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    # BUG FIX: without a default, omitting --gpu_id made the os.environ
    # assignment below raise TypeError (value None); '0' matches the
    # sibling script's default.
    parser.add_argument('--gpu_id',
                        type=str,
                        default='0',
                        help='cuda device id')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        help='how many batches to wait before logging training status')
    # NOTE(review): type=bool on argparse does not parse 'False' — any
    # non-empty string is truthy; kept as-is to preserve the CLI contract.
    parser.add_argument('--random',
                        type=bool,
                        default=False,
                        help='whether to use random')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    # Image-list files: one "path label" pair per line.
    source_list = '../data/svhn2mnist/svhn_balanced.txt'
    target_list = '../data/svhn2mnist/mnist_train.txt'
    test_list = '../data/svhn2mnist/mnist_test.txt'

    # Source loader (SVHN images are already 32x32, so no Resize).
    train_loader = torch.utils.data.DataLoader(ImageList(
        open(source_list).readlines(),
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, ), (0.5, ))]),
        mode='RGB'),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)
    # Target loader (MNIST is 28x28, resized up to 32x32 to match SVHN).
    train_loader1 = torch.utils.data.DataLoader(ImageList(
        open(target_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='RGB'),
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=1)
    test_loader = torch.utils.data.DataLoader(ImageList(
        open(test_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='RGB'),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              num_workers=1)

    model = network.DTN()
    model = model.cuda()
    class_num = 10

    # Adversarial head: either on a random multilinear projection (CDAN+R)
    # or directly on the feature x prediction outer product.
    if args.random:
        random_layer = network.RandomLayer([model.output_num(), class_num],
                                           500)
        ad_net = network.AdversarialNetwork(500, 500)
        random_layer.cuda()
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(model.output_num() * class_num,
                                            500)
    ad_net = ad_net.cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          weight_decay=0.0005,
                          momentum=0.9)
    optimizer_ad = optim.SGD(ad_net.parameters(),
                             lr=args.lr,
                             weight_decay=0.0005,
                             momentum=0.9)

    for epoch in range(1, args.epochs + 1):
        # Manual step decay: multiply LR by 0.3 every 3 epochs.
        if epoch % 3 == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.3
        train(args, model, ad_net, random_layer, train_loader, train_loader1,
              optimizer, optimizer_ad, epoch, 0, args.method)
        test(args, model, test_loader)
예제 #3
0
def main() -> None:
    """CDAN/DANN SVHN-to-MNIST training with CycleGAN-style auxiliary nets.

    Parses CLI options, writes a config log file, builds source/target/test
    loaders from image-list files, constructs a DTN feature extractor, two
    generator/discriminator pairs (s2t and t2s), an auxiliary classifier and
    the adversarial network, then trains for --epochs epochs with a manual
    LR decay.  Depends on names defined elsewhere in this project:
    ``network``, ``net``, ``ImageList``, ``transforms``, ``optim``,
    ``ReplayBuffer``, ``itertools``, ``train``, ``test``.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='CDAN SVHN MNIST')
    parser.add_argument('--method',
                        type=str,
                        default='CDAN-E',
                        choices=['CDAN', 'CDAN-E', 'DANN'])
    parser.add_argument('--task', default='USPS2MNIST', help='task to perform')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.03, metavar='LR')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--gpu_id',
                        type=str,
                        default='0',
                        help='cuda device id')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=50,
        help='how many batches to wait before logging training status')
    # NOTE(review): type=bool does not parse 'False' — any non-empty string
    # passed on the CLI is truthy.
    parser.add_argument('--random',
                        type=bool,
                        default=False,
                        help='whether to use random')
    parser.add_argument('--output_dir', type=str, default="digits/s2m")
    parser.add_argument('--cla_plus_weight', type=float, default=0.3)
    parser.add_argument('--cyc_loss_weight', type=float, default=0.01)
    # Comma-separated weights consumed by train() for the generator loss.
    parser.add_argument('--weight_in_loss_g',
                        type=str,
                        default='1,0.01,0.1,0.1')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    # Image-list files: one "path label" pair per line.
    source_list = '../data/svhn2mnist/svhn_balanced.txt'
    target_list = '../data/svhn2mnist/mnist_train.txt'
    test_list = '../data/svhn2mnist/mnist_test.txt'
    # train config
    config = {}

    config['method'] = args.method
    config["gpu"] = args.gpu_id
    config['cyc_loss_weight'] = args.cyc_loss_weight
    config['cla_plus_weight'] = args.cla_plus_weight
    config['weight_in_loss_g'] = args.weight_in_loss_g
    config["epochs"] = args.epochs
    config["output_for_test"] = True
    config["output_path"] = "snapshot/" + args.output_dir
    if not osp.exists(config["output_path"]):
        os.system('mkdir -p ' + config["output_path"])
    # Timestamped log file; NOTE(review): this handle is never closed.
    config["out_file"] = open(
        osp.join(
            config["output_path"], "log_svhn_to_mnist_{}.txt".format(
                str(datetime.datetime.utcnow()))), "w")

    config["out_file"].write(str(config))
    config["out_file"].flush()

    # Source loader (SVHN images are already 32x32, so no Resize).
    train_loader = torch.utils.data.DataLoader(ImageList(
        open(source_list).readlines(),
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, ), (0.5, ))]),
        mode='RGB'),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)
    # Target loader (MNIST is 28x28, resized up to 32x32 to match SVHN).
    train_loader1 = torch.utils.data.DataLoader(ImageList(
        open(target_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='RGB'),
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=1)
    test_loader = torch.utils.data.DataLoader(ImageList(
        open(test_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='RGB'),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              num_workers=1)

    model = network.DTN()
    model = model.cuda()
    class_num = 10

    # Add generators (G), discriminators (D), and an extra classifier
    z_dimension = 512
    D_s = network.models["Discriminator_digits"]()
    D_s = D_s.cuda()
    G_s2t = network.models["Generator_digits"](z_dimension, 1024)
    G_s2t = G_s2t.cuda()

    D_t = network.models["Discriminator_digits"]()
    D_t = D_t.cuda()
    G_t2s = network.models["Generator_digits"](z_dimension, 1024)
    G_t2s = G_t2s.cuda()

    # CycleGAN-style criteria: adversarial (MSE), cycle/identity/semantic (L1).
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()
    criterion_Sem = torch.nn.L1Loss()

    # One optimizer over both generators, one per discriminator.
    optimizer_G = torch.optim.Adam(itertools.chain(G_s2t.parameters(),
                                                   G_t2s.parameters()),
                                   lr=0.0003)
    optimizer_D_s = torch.optim.Adam(D_s.parameters(), lr=0.0003)
    optimizer_D_t = torch.optim.Adam(D_t.parameters(), lr=0.0003)

    # Buffers of previously generated fakes for discriminator training.
    fake_S_buffer = ReplayBuffer()
    fake_T_buffer = ReplayBuffer()

    ## Add the auxiliary classifier
    classifier1 = net.Net(512, class_num)
    classifier1 = classifier1.cuda()
    classifier1_optim = optim.Adam(classifier1.parameters(), lr=0.0003)

    # Adversarial head: random multilinear projection or full outer product.
    if args.random:
        random_layer = network.RandomLayer([model.output_num(), class_num],
                                           500)
        ad_net = network.AdversarialNetwork(500, 500)
        random_layer.cuda()
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(model.output_num() * class_num,
                                            500)
    ad_net = ad_net.cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          weight_decay=0.0005,
                          momentum=0.9)
    optimizer_ad = optim.SGD(ad_net.parameters(),
                             lr=args.lr,
                             weight_decay=0.0005,
                             momentum=0.9)

    for epoch in range(1, args.epochs + 1):
        # Manual step decay: multiply LR by 0.3 every 3 epochs.
        if epoch % 3 == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.3
        train(args, model, ad_net, random_layer, train_loader, train_loader1,
              optimizer, optimizer_ad, epoch, 0, args.method, D_s, D_t, G_s2t,
              G_t2s, criterion_Sem, criterion_GAN, criterion_cycle,
              criterion_identity, optimizer_G, optimizer_D_t, optimizer_D_s,
              classifier1, classifier1_optim, fake_S_buffer, fake_T_buffer)
        test(args, epoch, config, model, test_loader)
    # NOTE(review): everything below references names never defined in this
    # function (use_cuda, args.dataset, args.arch, args.ckpt_a, args.ckpt_b,
    # copy, ...) and would raise NameError/AttributeError if reached.  It
    # looks like an unrelated snippet was pasted here (the missing
    # "예제 #4" section?).  TODO: confirm and move/remove.
    # Define what device we are using
    print("CUDA Available: ",torch.cuda.is_available())
    device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")

    # Initialize the network
    if args.dataset == "cifar100":
        args.n_classes = 100
    elif args.dataset in ["cifar10", "stl10"]:
        args.n_classes = 9
    elif args.dataset in ["usps", "mnist", "svhn", 'syn_digits', ]:
        args.n_classes = 10
    else:
        raise ValueError('invalid dataset option: {}'.format(args.dataset))

    if args.arch == "DTN":
        model_A = network.DTN().to(device)
    elif args.arch == 'wrn':
        model_A = network.WideResNet(depth=28, num_classes=args.n_classes, widen_factor=10, dropRate=0.0).to(device)
    else:
        raise ValueError('invalid network architecture {}'.format(args.arch))
    
    model_B = copy.deepcopy(model_A).to(device)

    # Load the pretrained model
    model_A.load_state_dict(torch.load(args.ckpt_a, map_location='cpu'))
    model_B.load_state_dict(torch.load(args.ckpt_b, map_location='cpu'))

    # Set the model in evaluation mode. In this case this is for the Dropout layers
    model_A.eval()
    model_B.eval()
예제 #5
0
def main() -> None:
    """CDAN/DANN SVHN-to-MNIST training with MDD/entropic loss weights.

    Parses CLI options, optionally seeds all RNGs for reproducibility,
    writes a config log, builds source/target/test loaders, trains a DTN
    feature extractor adversarially, tracks the best test accuracy, and
    saves the best model snapshot.  Depends on names defined elsewhere in
    this project: ``network``, ``ImageList``, ``transforms``, ``optim``,
    ``train``, ``test``.
    """
    parser = argparse.ArgumentParser(description='CDAN SVHN MNIST')
    parser.add_argument('--method',
                        type=str,
                        default='CDAN-E',
                        choices=['CDAN', 'CDAN-E', 'DANN'])
    parser.add_argument('--task', default='USPS2MNIST', help='task to perform')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.03, metavar='LR')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--gpu_id',
                        default='0',
                        type=str,
                        help='cuda device id')
    parser.add_argument('--seed',
                        type=int,
                        default=40,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        help='how many batches to wait before logging training status')
    # NOTE(review): type=bool does not parse 'False' — any non-empty string
    # passed on the CLI is truthy.
    parser.add_argument('--random',
                        type=bool,
                        default=False,
                        help='whether to use random')
    parser.add_argument("--mdd_weight", type=float, default=0)
    parser.add_argument("--entropic_weight", type=float, default=0)
    parser.add_argument("--weight", type=float, default=1)
    parser.add_argument("--left_weight", type=float, default=1)
    parser.add_argument("--right_weight", type=float, default=1)
    parser.add_argument('--use_seed', type=int, default=1)
    args = parser.parse_args()
    # Seed every RNG source for reproducibility when requested.
    if args.use_seed:
        import random
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        random.seed(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    import os.path as osp
    import time
    import copy
    config = {}

    config["use_seed"] = args.use_seed
    config['seed'] = args.seed
    config["output_path"] = "snapshot/s2m"
    config["mdd_weight"] = args.mdd_weight
    config["entropic_weight"] = args.entropic_weight
    config["weight"] = args.weight
    config["left_weight"] = args.left_weight
    config["right_weight"] = args.right_weight
    if not osp.exists(config["output_path"]):
        os.system('mkdir -p ' + config["output_path"])
    # Timestamped log file, keyed by wall-clock time and seed.
    config["out_file"] = open(
        osp.join(
            config["output_path"],
            "log_svhn_to_mnist_{}______{}.txt".format(str(int(time.time())),
                                                      str(args.seed))), "w")

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    # Image-list files: one "path label" pair per line.
    source_list = 'data/svhn2mnist/svhn_balanced.txt'
    target_list = 'data/svhn2mnist/mnist_train.txt'
    test_list = 'data/svhn2mnist/mnist_test.txt'

    # Source loader (SVHN images are already 32x32, so no Resize).
    train_loader = torch.utils.data.DataLoader(ImageList(
        open(source_list).readlines(),
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, ), (0.5, ))]),
        mode='RGB'),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=0)
    # Target loader (MNIST is 28x28, resized up to 32x32 to match SVHN).
    train_loader1 = torch.utils.data.DataLoader(ImageList(
        open(target_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='RGB'),
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=0)
    test_loader = torch.utils.data.DataLoader(ImageList(
        open(test_list).readlines(),
        transform=transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]),
        mode='RGB'),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              num_workers=0)

    model = network.DTN()
    model = model.cuda()
    class_num = 10

    # Adversarial head: random multilinear projection or full outer product.
    if args.random:
        random_layer = network.RandomLayer([model.output_num(), class_num],
                                           500)
        ad_net = network.AdversarialNetwork(500, 500)
        random_layer.cuda()
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(model.output_num() * class_num,
                                            500)
    ad_net = ad_net.cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          weight_decay=0.0005,
                          momentum=0.9)
    optimizer_ad = optim.SGD(ad_net.parameters(),
                             lr=args.lr,
                             weight_decay=0.0005,
                             momentum=0.9)

    config["out_file"].write(str(config))
    config["out_file"].flush()
    # BUG FIX: 'best_model = model' only aliased the live model, so the
    # final save always wrote the last-epoch weights, never the best ones.
    # Snapshot with copy.deepcopy instead.
    best_model = copy.deepcopy(model)
    best_acc = 0

    for epoch in range(1, args.epochs + 1):
        # Manual step decay: multiply LR by 0.3 every 3 epochs.
        if epoch % 3 == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.3

        train(args, config, model, ad_net, random_layer, train_loader,
              train_loader1, optimizer, optimizer_ad, epoch)
        acc = test(epoch, config, model, test_loader)
        if (acc > best_acc):
            best_model = copy.deepcopy(model)
            best_acc = acc

    # Close the log file now that train()/test() are done writing to it.
    config["out_file"].close()
    # BUG FIX: the snapshot directory was never created, so torch.save
    # raised FileNotFoundError on a fresh checkout.
    os.makedirs("snapshot/s2m_model", exist_ok=True)
    torch.save(
        best_model,
        osp.join("snapshot/s2m_model",
                 "s2m_{}_{}".format(str(best_acc), str(args.mdd_weight))))