Example #1
import wideresnet
import wideresnet2


def create_model(num_classes, model_type):
    # Dispatch on model_type and build the matching WideResNet variant.
    if model_type == 'wide22':
        model = wideresnet.WideResNet(22, num_classes, widen_factor=1, dropRate=0.0, leakyRate=0.1)
    elif model_type == 'wide28':
        model = wideresnet.WideResNet(28, num_classes, widen_factor=2, dropRate=0.0, leakyRate=0.1)
    elif model_type == 'wide28_2':
        model = wideresnet2.WideResNet(num_classes)
    else:
        raise ValueError("Unknown model_type: {}".format(model_type))
    return model
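
A minimal usage sketch, assuming the project-local wideresnet / wideresnet2 modules are importable and a 10-class problem:

# Hypothetical usage: build a WideResNet-28 (widen factor 2) for 10 classes.
model = create_model(num_classes=10, model_type='wide28')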
Example #2
def two_head_net(model, out_features, fileout='', pre_train=False):
    # Build a network with two classifier heads. If `pre_train` is set, the
    # backbone and both heads are initialised from the checkpoint `fileout`;
    # otherwise both heads start from identical random weights.
    if pre_train:
        if model == 'densenet':
            two_head_net = d.DenseNet3(100, out_features).cuda()

            checkpoint = torch.load(fileout, map_location='cuda:0')
            two_head_net.load_state_dict(checkpoint)

            # Reuse the pretrained classifier weights for both new heads.
            Linear = list(two_head_net.children())[-1]
            Linear = Linear.state_dict()
            Linear1 = nn.Linear(in_features=342,
                                out_features=out_features,
                                bias=True)
            Linear1.load_state_dict(Linear)
            Linear1.cuda()
            Linear2 = nn.Linear(in_features=342,
                                out_features=out_features,
                                bias=True)
            Linear2.load_state_dict(Linear)
            Linear2.cuda()

            two_head_net = two_head_dense(two_head_net, Linear1, Linear2)
        elif model == 'wideresnet':
            two_head_net = wrn.WideResNet(out_features).cuda()
            checkpoint = torch.load(fileout, map_location='cuda:0')
            two_head_net.load_state_dict(checkpoint)

            Linear = list(two_head_net.children())[-1]
            Linear = Linear.state_dict()
            Linear1 = nn.Linear(in_features=640,
                                out_features=out_features,
                                bias=True)
            Linear1.load_state_dict(Linear)
            Linear1.cuda()
            Linear2 = nn.Linear(in_features=640,
                                out_features=out_features,
                                bias=True)
            Linear2.load_state_dict(Linear)
            Linear2.cuda()

            two_head_net = two_head_wide(two_head_net, Linear1, Linear2)
    else:
        if model == 'densenet':
            two_head_net = d.DenseNet3(100, out_features).cuda()
            Linear1 = nn.Linear(in_features=342,
                                out_features=out_features,
                                bias=True)
            Linear1.cuda()
            Linear2 = nn.Linear(in_features=342,
                                out_features=out_features,
                                bias=True)
            Linear2.load_state_dict(Linear1.state_dict())
            Linear2.cuda()
            two_head_net = two_head_dense(two_head_net, Linear1, Linear2)
        elif model == 'wideresnet':
            two_head_net = wrn.WideResNet(out_features).cuda()
            # WideResNet penultimate features are 640-dimensional, matching
            # the pretrained branch above.
            Linear1 = nn.Linear(in_features=640,
                                out_features=out_features,
                                bias=True)
            Linear1.cuda()
            Linear2 = nn.Linear(in_features=640,
                                out_features=out_features,
                                bias=True)
            Linear2.load_state_dict(Linear1.state_dict())
            Linear2.cuda()
            two_head_net = two_head_wide(two_head_net, Linear1, Linear2)

    return two_head_net
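
The wrapper classes two_head_dense and two_head_wide are not shown in this excerpt. A minimal sketch of what such a wrapper might look like, assuming the backbone can be called to produce penultimate features that both linear heads consume:

import torch.nn as nn

# Hypothetical two-head wrapper in the spirit of two_head_dense /
# two_head_wide: one shared backbone feeding two independent linear heads.
class TwoHeadWrapper(nn.Module):
    def __init__(self, backbone, head1, head2):
        super(TwoHeadWrapper, self).__init__()
        self.backbone = backbone  # assumed to return penultimate features
        self.head1 = head1
        self.head2 = head2

    def forward(self, x):
        feats = self.backbone(x)
        return self.head1(feats), self.head2(feats)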
Example #3
def main(args=args):
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    zca_mean = None
    zca_components = None

    # build dataset
    if args.dataset == "Cifar10":
        dataset_base_path = path.join(args.base_path, "dataset", "cifar")
        train_dataset = cifar10_dataset(dataset_base_path)
        test_dataset = cifar10_dataset(dataset_base_path, train_flag=False)
        sampler_valid, sampler_train_l, sampler_train_u = get_ssl_sampler(
            torch.tensor(train_dataset.targets, dtype=torch.int32), 500, 400,
            10)
        test_dloader = DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.workers,
                                  pin_memory=True)
        valid_dloader = DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   sampler=sampler_valid)
        train_dloader_l = DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_l)
        train_dloader_u = DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_u)
        num_classes = 10
        if args.zca:
            zca_mean = np.load(
                os.path.join(dataset_base_path, 'cifar10_zca_mean.npy'))
            zca_components = np.load(
                os.path.join(dataset_base_path, 'cifar10_zca_components.npy'))
            zca_mean = torch.from_numpy(zca_mean).view(1, -1).float().cuda()
            zca_components = torch.from_numpy(zca_components).float().cuda()
    elif args.dataset == "Cifar100":
        dataset_base_path = path.join(args.base_path, "dataset", "cifar")
        train_dataset = cifar100_dataset(dataset_base_path)
        test_dataset = cifar100_dataset(dataset_base_path, train_flag=False)
        sampler_valid, sampler_train_l, sampler_train_u = get_ssl_sampler(
            torch.tensor(train_dataset.targets, dtype=torch.int32), 50, 40,
            100)
        test_dloader = DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.workers,
                                  pin_memory=True)
        valid_dloader = DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   sampler=sampler_valid)
        train_dloader_l = DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_l)
        train_dloader_u = DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_u)
        num_classes = 100
    elif args.dataset == "SVHN":
        dataset_base_path = path.join(args.base_path, "dataset", "svhn")
        train_dataset = svhn_dataset(dataset_base_path)
        test_dataset = svhn_dataset(dataset_base_path, train_flag=False)
        sampler_valid, sampler_train_l, sampler_train_u = get_ssl_sampler(
            torch.tensor(train_dataset.labels, dtype=torch.int32), 732, 100,
            10)
        test_dloader = DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.workers,
                                  pin_memory=True)
        valid_dloader = DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   sampler=sampler_valid)
        train_dloader_l = DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_l)
        train_dloader_u = DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_u)
        num_classes = 10
    else:
        raise NotImplementedError("Dataset {} Not Implemented".format(
            args.dataset))
    # build network
    if args.net_name == "wideresnet":
        model = wideresnet.WideResNet(depth=args.depth,
                                      width=args.width,
                                      num_classes=num_classes,
                                      data_parallel=args.dp,
                                      drop_rate=args.dr)
    elif "preact" in args.net_name:
        model = get_preact_resnet(args.net_name,
                                  num_classes=num_classes,
                                  data_parallel=args.dp,
                                  drop_rate=args.dr)
    elif "densenet" in args.net_name:
        model = get_densenet(args.net_name,
                             num_classes=num_classes,
                             data_parallel=args.dp,
                             drop_rate=args.dr)
    else:
        raise NotImplementedError("model {} not implemented".format(
            args.net_name))
    model = model.cuda()

    input(
        "Begin training run {} (semi-supervised), Dataset:{} Mixup Method:{} "
        "Manifold Mixup Method:{}".format(args.train_time, args.dataset,
                                          args.mixup, args.manifold_mixup))
    criterion_l = nn.CrossEntropyLoss()
    criterion_u = nn.MSELoss()
    if args.optimizer == "SGD":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.wd,
                                    nesterov=args.nesterov)
    else:
        raise NotImplementedError("optimizer {} not found".format(args.optimizer))
    scheduler = MultiStepLR(optimizer,
                            milestones=args.adjust_lr,
                            gamma=args.lr_decay_ratio)
    writer_log_dir = "{}/{}-SSL/runs/train_time:{}".format(
        args.base_path, args.dataset, args.train_time)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if args.resume_arg:
                args = checkpoint['args']
                args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            raise FileNotFoundError(
                "Checkpoint Resume File {} Not Found".format(args.resume))
    else:
        if os.path.exists(writer_log_dir):
            flag = input(
                "{}-SSL train_time:{} will be removed, input yes to continue:".
                format(args.dataset, args.train_time))
            if flag == "yes":
                shutil.rmtree(writer_log_dir, ignore_errors=True)
    writer = SummaryWriter(log_dir=writer_log_dir)
    for epoch in range(args.start_epoch, args.epochs):
        scheduler.step(epoch)
        if epoch == 0:
            # do warm up
            modify_lr_rate(opt=optimizer, lr=args.wul)
        alpha = alpha_schedule(epoch=epoch)
        train(train_dloader_l,
              train_dloader_u,
              model=model,
              criterion_l=criterion_l,
              criterion_u=criterion_u,
              optimizer=optimizer,
              epoch=epoch,
              writer=writer,
              alpha=alpha,
              zca_mean=zca_mean,
              zca_components=zca_components)
        test(valid_dloader,
             test_dloader,
             model=model,
             criterion=criterion_l,
             epoch=epoch,
             writer=writer,
             num_classes=num_classes,
             zca_mean=zca_mean,
             zca_components=zca_components)
        save_checkpoint({
            'epoch': epoch + 1,
            'args': args,
            "state_dict": model.state_dict(),
            'optimizer': optimizer.state_dict(),
        })
        if epoch == 0:
            modify_lr_rate(opt=optimizer, lr=args.lr)
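
The helpers alpha_schedule and modify_lr_rate are not defined in this excerpt. A minimal sketch of plausible implementations, assuming a sigmoid ramp-up of the unsupervised-loss weight (a common choice in this line of SSL work; the max weight and ramp length here are illustrative defaults, not values from the script):

import numpy as np

# Hypothetical alpha_schedule: sigmoid ramp-up of the unsupervised-loss
# weight over the first `ramp_epochs` epochs (the actual schedule used
# by this script is not shown above).
def alpha_schedule(epoch, max_alpha=1.0, ramp_epochs=80):
    if epoch >= ramp_epochs:
        return max_alpha
    t = epoch / float(ramp_epochs)
    return max_alpha * float(np.exp(-5.0 * (1.0 - t) ** 2))


# Hypothetical modify_lr_rate: set the learning rate of every parameter
# group, as used for the epoch-0 warm-up above.
def modify_lr_rate(opt, lr):
    for param_group in opt.param_groups:
        param_group['lr'] = lr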