def create_model(num_classes, model_type):
    """Instantiate a WideResNet variant for classification.

    Args:
        num_classes: number of output classes.
        model_type: one of 'wide22' (depth 22, widen 1), 'wide28'
            (depth 28, widen 2), or 'wide28_2' (wideresnet2 variant).

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: if ``model_type`` is not recognised.
            (The original silently fell through and crashed with
            UnboundLocalError on ``return model``.)
    """
    if model_type == 'wide22':
        model = wideresnet.WideResNet(22, num_classes, widen_factor=1,
                                      dropRate=0.0, leakyRate=0.1)
    elif model_type == 'wide28':
        model = wideresnet.WideResNet(28, num_classes, widen_factor=2,
                                      dropRate=0.0, leakyRate=0.1)
    elif model_type == 'wide28_2':
        model = wideresnet2.WideResNet(num_classes)
    else:
        # Fail loudly and consistently with the rest of the file
        # (main() raises NotImplementedError for unknown nets/datasets).
        raise NotImplementedError("model {} not implemented".format(model_type))
    return model
def two_head_net(model, out_features, fileout='', pre_train=False):
    """Build a backbone network with two identical linear classifier heads.

    Args:
        model: backbone type, 'densenet' or 'wideresnet'.
        out_features: number of output classes for each head.
        fileout: checkpoint path, used only when ``pre_train`` is True.
        pre_train: if True, load backbone weights from ``fileout`` and copy
            the pretrained classifier weights into both heads; otherwise both
            heads share one fresh random initialisation.

    Returns:
        A two-head wrapper module (``two_head_dense`` or ``two_head_wide``)
        moved to CUDA.

    Raises:
        NotImplementedError: if ``model`` is not a known backbone type.
    """

    def _make_heads(in_features, head_state=None):
        # Build two nn.Linear heads with identical weights: copied from
        # `head_state` when given, otherwise from head1's fresh init.
        head1 = nn.Linear(in_features=in_features,
                          out_features=out_features, bias=True)
        if head_state is None:
            head_state = head1.state_dict()
        else:
            head1.load_state_dict(head_state)
        head1.cuda()
        head2 = nn.Linear(in_features=in_features,
                          out_features=out_features, bias=True)
        head2.load_state_dict(head_state)
        head2.cuda()
        return head1, head2

    if model == 'densenet':
        backbone = d.DenseNet3(100, out_features).cuda()
        feat_dim = 342  # penultimate feature width used by the pretrain path too
        wrapper = two_head_dense
    elif model == 'wideresnet':
        backbone = wrn.WideResNet(out_features).cuda()
        # BUG FIX: the original non-pretrained branch built the heads with
        # in_features=342 (the DenseNet width); the pretrained branch of this
        # same function uses 640 for WideResNet, so 640 is correct here.
        feat_dim = 640
        wrapper = two_head_wide
    else:
        raise NotImplementedError("model {} not implemented".format(model))

    if pre_train:
        checkpoint = torch.load(fileout, map_location='cuda:0')
        backbone.load_state_dict(checkpoint)
        # Clone the pretrained classifier (last child module) into both heads.
        pretrained_head = list(backbone.children())[-1].state_dict()
        head1, head2 = _make_heads(feat_dim, pretrained_head)
    else:
        head1, head2 = _make_heads(feat_dim)

    return wrapper(backbone, head1, head2)
def main(args=args):
    """Script entry point: build SSL dataloaders, model, optimizer/scheduler,
    then run the train/test loop with checkpointing.

    NOTE(review): the default ``args=args`` binds the module-level args
    object at definition time — presumably the argparse namespace parsed at
    import; confirm against the top of the file.
    """
    # Restrict visible GPUs before any CUDA allocation happens.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # ZCA whitening statistics; only populated for Cifar10 when args.zca is set,
    # and forwarded to train()/test() below.
    zca_mean = None
    zca_components = None
    # build dataset
    if args.dataset == "Cifar10":
        dataset_base_path = path.join(args.base_path, "dataset", "cifar")
        train_dataset = cifar10_dataset(dataset_base_path)
        test_dataset = cifar10_dataset(dataset_base_path, train_flag=False)
        # Split train targets into valid / labeled / unlabeled samplers.
        # NOTE(review): 500 and 400 are presumably per-class valid/labeled
        # counts for the 10 classes — confirm against get_ssl_sampler.
        sampler_valid, sampler_train_l, sampler_train_u = get_ssl_sampler(
            torch.tensor(train_dataset.targets, dtype=torch.int32),
            500, 400, 10)
        test_dloader = DataLoader(test_dataset, batch_size=args.batch_size,
                                  num_workers=args.workers, pin_memory=True)
        # Validation draws from the *train* dataset via sampler_valid.
        valid_dloader = DataLoader(train_dataset, batch_size=args.batch_size,
                                   num_workers=args.workers, pin_memory=True,
                                   sampler=sampler_valid)
        train_dloader_l = DataLoader(train_dataset, batch_size=args.batch_size,
                                     num_workers=args.workers, pin_memory=True,
                                     sampler=sampler_train_l)
        train_dloader_u = DataLoader(train_dataset, batch_size=args.batch_size,
                                     num_workers=args.workers, pin_memory=True,
                                     sampler=sampler_train_u)
        num_classes = 10
        if args.zca:
            # Precomputed ZCA stats stored beside the dataset; moved to GPU
            # once here, reused every epoch.
            zca_mean = np.load(
                os.path.join(dataset_base_path, 'cifar10_zca_mean.npy'))
            zca_components = np.load(
                os.path.join(dataset_base_path, 'cifar10_zca_components.npy'))
            zca_mean = torch.from_numpy(zca_mean).view(1, -1).float().cuda()
            zca_components = torch.from_numpy(zca_components).float().cuda()
    elif args.dataset == "Cifar100":
        dataset_base_path = path.join(args.base_path, "dataset", "cifar")
        train_dataset = cifar100_dataset(dataset_base_path)
        test_dataset = cifar100_dataset(dataset_base_path, train_flag=False)
        # 50 valid / 40 labeled per class across 100 classes (see note above).
        sampler_valid, sampler_train_l, sampler_train_u = get_ssl_sampler(
            torch.tensor(train_dataset.targets, dtype=torch.int32),
            50, 40, 100)
        test_dloader = DataLoader(test_dataset, batch_size=args.batch_size,
                                  num_workers=args.workers, pin_memory=True)
        valid_dloader = DataLoader(train_dataset, batch_size=args.batch_size,
                                   num_workers=args.workers, pin_memory=True,
                                   sampler=sampler_valid)
        train_dloader_l = DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers, pin_memory=True,
                                     sampler=sampler_train_l)
        train_dloader_u = DataLoader(train_dataset, batch_size=args.batch_size,
                                     num_workers=args.workers, pin_memory=True,
                                     sampler=sampler_train_u)
        num_classes = 100
    elif args.dataset == "SVHN":
        dataset_base_path = path.join(args.base_path, "dataset", "svhn")
        train_dataset = svhn_dataset(dataset_base_path)
        test_dataset = svhn_dataset(dataset_base_path, train_flag=False)
        # SVHN exposes .labels (not .targets like the CIFAR datasets).
        sampler_valid, sampler_train_l, sampler_train_u = get_ssl_sampler(
            torch.tensor(train_dataset.labels, dtype=torch.int32),
            732, 100, 10)
        test_dloader = DataLoader(test_dataset, batch_size=args.batch_size,
                                  num_workers=args.workers, pin_memory=True)
        valid_dloader = DataLoader(train_dataset, batch_size=args.batch_size,
                                   num_workers=args.workers, pin_memory=True,
                                   sampler=sampler_valid)
        train_dloader_l = DataLoader(train_dataset, batch_size=args.batch_size,
                                     num_workers=args.workers, pin_memory=True,
                                     sampler=sampler_train_l)
        train_dloader_u = DataLoader(train_dataset, batch_size=args.batch_size,
                                     num_workers=args.workers, pin_memory=True,
                                     sampler=sampler_train_u)
        num_classes = 10
    else:
        raise NotImplementedError("Dataset {} Not Implemented".format(
            args.dataset))
    # Select the backbone; substring matching routes e.g. "preact18" and
    # "densenet100" to their factory helpers.
    if args.net_name == "wideresnet":
        model = wideresnet.WideResNet(depth=args.depth, width=args.width,
                                      num_classes=num_classes,
                                      data_parallel=args.dp,
                                      drop_rate=args.dr)
    elif "preact" in args.net_name:
        model = get_preact_resnet(args.net_name, num_classes=num_classes,
                                  data_parallel=args.dp, drop_rate=args.dr)
    elif "densenet" in args.net_name:
        model = get_densenet(args.net_name, num_classes=num_classes,
                             data_parallel=args.dp, drop_rate=args.dr)
    else:
        raise NotImplementedError("model {} not implemented".format(
            args.net_name))
    model = model.cuda()
    # NOTE(review): input() blocks here until the user presses Enter —
    # looks like it may have been meant as print(); confirm intent.
    input(
        "Begin the {} time's semi-supervised training, Dataset:{} Mixup Method:{} \
 Manifold Mixup Method :{}".format(args.train_time, args.dataset, args.mixup,
                                   args.manifold_mixup))
    # Supervised loss on labeled batches; consistency (MSE) loss on unlabeled.
    criterion_l = nn.CrossEntropyLoss()
    criterion_u = nn.MSELoss()
    if args.optimizer == "SGD":
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.wd,
                                    nesterov=args.nesterov)
    else:
        raise NotImplementedError("{} not find".format(args.optimizer))
    scheduler = MultiStepLR(optimizer, milestones=args.adjust_lr,
                            gamma=args.lr_decay_ratio)
    writer_log_dir = "{}/{}-SSL/runs/train_time:{}".format(
        args.base_path, args.dataset, args.train_time)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if args.resume_arg:
                # Replace the live args namespace with the one saved in the
                # checkpoint (everything after this uses the restored args).
                args = checkpoint['args']
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            raise FileNotFoundError(
                "Checkpoint Resume File {} Not Found".format(args.resume))
    else:
        # Fresh run: offer to wipe an existing TensorBoard dir for this
        # train_time so old scalars don't mix with the new run.
        if os.path.exists(writer_log_dir):
            flag = input(
                "{}-SSL train_time:{} will be removed, input yes to continue:".
                format(args.dataset, args.train_time))
            if flag == "yes":
                shutil.rmtree(writer_log_dir, ignore_errors=True)
    writer = SummaryWriter(log_dir=writer_log_dir)
    for epoch in range(args.start_epoch, args.epochs):
        # Scheduler stepped with an explicit epoch index (legacy PyTorch
        # scheduler API) before the epoch's optimizer updates.
        scheduler.step(epoch)
        if epoch == 0:
            # do warm up
            modify_lr_rate(opt=optimizer, lr=args.wul)
        # Ramp-up weight for the unsupervised consistency term.
        alpha = alpha_schedule(epoch=epoch)
        train(train_dloader_l,
              train_dloader_u,
              model=model,
              criterion_l=criterion_l,
              criterion_u=criterion_u,
              optimizer=optimizer,
              epoch=epoch,
              writer=writer,
              alpha=alpha,
              zca_mean=zca_mean,
              zca_components=zca_components)
        test(valid_dloader,
             test_dloader,
             model=model,
             criterion=criterion_l,
             epoch=epoch,
             writer=writer,
             num_classes=num_classes,
             zca_mean=zca_mean,
             zca_components=zca_components)
        # Checkpoint every epoch; 'epoch' is stored as the NEXT epoch index
        # so a resume picks up where this run left off.
        save_checkpoint({
            'epoch': epoch + 1,
            'args': args,
            "state_dict": model.state_dict(),
            'optimizer': optimizer.state_dict(),
        })
        if epoch == 0:
            # Restore the configured base LR after the one-epoch warm-up.
            modify_lr_rate(opt=optimizer, lr=args.lr)