def test():
    """Evaluate saved EMA checkpoints on the test set.

    Collects checkpoints either from a single ``--checkpoint_file`` or by
    walking ``--checkpoint_dir`` for ``*.bin`` files, then runs every batch
    of the test loader through each restored EMA model.
    """
    device = torch.device('cpu')
    arguments = load_test_arguments()

    # Map checkpoint display-name -> path for everything we will evaluate.
    checkpoints = {}
    if arguments.checkpoint_file:
        checkpoints[arguments.checkpoint_file] = arguments.checkpoint_file
    else:
        for root, dirs, files in os.walk(arguments.checkpoint_dir):
            for file in files:
                if file.endswith('.bin'):
                    checkpoints[file] = os.path.join(root, file)

    test_dataloader = load_test_data(arguments)
    for checkpoint_name in checkpoints:
        checkpoint = torch.load(checkpoints[checkpoint_name], map_location=device)
        log_checkpoint(checkpoint)

        ema_model = WideResNet(num_classes=10)
        ema_model.load_state_dict(checkpoint['ema_model_state'])
        # Fix: evaluate on the target device and in eval mode — the original
        # left the model in train mode, so BatchNorm/Dropout behaved as in
        # training and skewed the reported accuracy.
        ema_model.to(device)
        ema_model.eval()

        batches = len(test_dataloader)
        metrics = {'test_steps': 0, 'test_accuracy': 0}
        test_progress_bar = tqdm(enumerate(test_dataloader))
        for batch_step, batch in test_progress_bar:
            test_step(ema_model, batch, metrics, device)
            on_test_batch_end(batch_step, batches, metrics, test_progress_bar)
class WeightEMA(object):
    """Maintains an exponential-moving-average (EMA) copy of a model.

    ``step()`` blends the live model's parameters into ``ema_model`` with
    decay ``alpha`` and applies a custom weight decay to the live model.
    ``step(bn=True)`` additionally copies BatchNorm running statistics from
    the live model into the EMA model without disturbing the EMA parameters.
    """

    def __init__(self, model, ema_model, run_type=0, alpha=0.999, lr=None):
        """Args:
            model: live network being trained.
            ema_model: shadow network updated as the EMA of ``model``.
            run_type: 1 builds a Classifier buffer model, else a WideResNet.
            alpha: EMA decay rate.
            lr: learning rate used to scale the custom weight decay.
                Fix/generalization: previously only the module-global
                ``args.lr`` was used; ``None`` keeps that behavior.
        """
        self.model = model
        self.ema_model = ema_model
        self.alpha = alpha
        # Buffer network used to stash EMA parameters while BN statistics
        # are synced in step(bn=True).
        if run_type == 1:
            self.tmp_model = Classifier(num_classes=10).cuda()
        else:
            self.tmp_model = WideResNet(num_classes=10).cuda()
        self.wd = 0.02 * (args.lr if lr is None else lr)
        # Start the EMA weights equal to the live weights.
        for param, ema_param in zip(self.model.parameters(), self.ema_model.parameters()):
            ema_param.data.copy_(param.data)

    def step(self, bn=False):
        if bn:
            # Copy BatchNorm stats into the EMA model: stash the EMA
            # parameters, load the full live state dict (which carries the
            # BN running buffers along), then restore the stashed parameters.
            for ema_param, tmp_param in zip(self.ema_model.parameters(), self.tmp_model.parameters()):
                tmp_param.data.copy_(ema_param.data.detach())
            self.ema_model.load_state_dict(self.model.state_dict())
            for ema_param, tmp_param in zip(self.ema_model.parameters(), self.tmp_model.parameters()):
                ema_param.data.copy_(tmp_param.data.detach())
        else:
            one_minus_alpha = 1.0 - self.alpha
            for param, ema_param in zip(self.model.parameters(), self.ema_model.parameters()):
                # ema = alpha * ema + (1 - alpha) * param
                ema_param.data.mul_(self.alpha)
                ema_param.data.add_(param.data.detach() * one_minus_alpha)
                # Customized weight decay applied to the live model.
                param.data.mul_(1 - self.wd)
def __init__(self, batch_size, model_params, n_steps, optimizer, adam, sgd,
             steps_validation, steps_checkpoint, dataset, save_path):
    """Set up data loaders, the WideResNet/EMA pair, the optimizer and all
    training bookkeeping for fully supervised training."""
    # -- schedule / bookkeeping --
    self.n_steps = n_steps
    self.start_step = 0
    self.steps_validation = steps_validation
    self.steps_checkpoint = steps_checkpoint
    self.num_labeled = 50000

    # -- data --
    loaders = get_dataloaders_with_index(path='../data',
                                         batch_size=batch_size,
                                         num_labeled=self.num_labeled,
                                         which_dataset=dataset,
                                         validation=False)
    (self.train_loader, _, self.val_loader, self.test_loader,
     self.lbl_idx, _, self.val_idx) = loaders
    print('Labeled samples: ' + str(len(self.train_loader.sampler)))
    # Re-read the effective batch size from the loader itself.
    self.batch_size = self.train_loader.batch_size

    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(self.device)

    # -- model and its EMA shadow --
    depth, k, n_out = model_params
    self.model = WideResNet(depth=depth, k=k, n_out=n_out, bias=True).to(self.device)
    self.ema_model = WideResNet(depth=depth, k=k, n_out=n_out, bias=True).to(self.device)
    for p in self.ema_model.parameters():
        p.detach_()  # EMA weights never receive gradients

    # -- optimizer --
    if optimizer == 'adam':
        self.lr, self.weight_decay = adam
        self.momentum, self.lr_decay = None, None
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.ema_optimizer = WeightEMA(self.model, self.ema_model, self.lr, alpha=0.999)
    else:
        self.lr, self.momentum, self.weight_decay, self.lr_decay = sgd
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr,
                                   momentum=self.momentum,
                                   weight_decay=self.weight_decay,
                                   nesterov=True)
        self.ema_optimizer = None

    self.criterion = nn.CrossEntropyLoss()
    self.train_accuracies, self.train_losses = [], []
    self.val_accuracies, self.val_losses = [], []
    self.best_acc = 0
    self.augment = Augment(K=1)
    self.path = save_path
    self.writer = SummaryWriter()
def __init__(self, model, ema_model, run_type=0, alpha=0.999, lr=None):
    """Initialize EMA tracking of ``model`` into ``ema_model``.

    Args:
        model: live network whose weights are tracked.
        ema_model: shadow network updated as the exponential moving average.
        run_type: 1 builds a Classifier buffer model, else a WideResNet.
        alpha: EMA decay rate.
        lr: learning rate used to scale the custom weight decay. Fix: the
            original read only the module-global ``args.lr``; ``None`` keeps
            that behavior, so existing callers are unaffected.
    """
    self.model = model
    self.ema_model = ema_model
    self.alpha = alpha
    # Buffer network used elsewhere to stash parameters (e.g. while syncing
    # BatchNorm statistics).
    if run_type == 1:
        self.tmp_model = Classifier(num_classes=10).cuda()
    else:
        self.tmp_model = WideResNet(num_classes=10).cuda()
    self.wd = 0.02 * (args.lr if lr is None else lr)
    # Start the EMA weights equal to the live weights.
    for param, ema_param in zip(self.model.parameters(), self.ema_model.parameters()):
        ema_param.data.copy_(param.data)
def get_model(name, num_classes=10, normalize_input=False, pretrain=False):
    """Build a network from a dash-separated spec such as ``wrn-28-10``.

    Supported families: ``wrn-<depth>-<widen>``, ``ss-<depth>-<channels>``
    and ``resnet-<depth>``. Raises ValueError for anything else.
    """
    parts = name.split('-')
    family = parts[0]
    if family == 'wrn':
        depth, widen = int(parts[1]), int(parts[2])
        wrn_cls = WideResNetPre if pretrain else WideResNet
        model = wrn_cls(depth=depth, num_classes=num_classes, widen_factor=widen)
    elif family == 'ss':
        model = ShakeNet(dict(depth=int(parts[1]),
                              base_channels=int(parts[2]),
                              shake_forward=True,
                              shake_backward=True,
                              shake_image=True,
                              input_shape=(1, 3, 32, 32),
                              n_classes=num_classes,
                              ))
    elif family == 'resnet':
        depth = int(parts[1])
        resnets = {18: ResNet18, 50: ResNet50, 152: ResNet152}
        if depth in resnets:
            model = resnets[depth]()  # regular resnets
        else:
            model = ResNet(num_classes=num_classes, depth=depth)  # CIFAR-ResNet
    else:
        raise ValueError('Could not parse model name %s' % name)
    if normalize_input:
        model = Sequential(NormalizeInput(), model)
    return model
def setup(self, flags):
    """Build the network and the log directory from parsed flags.

    Configures cuDNN determinism, seeds RNGs, instantiates the model named
    by ``flags.model`` (sized for CIFAR-10/100), moves it to the GPU and
    writes the flags to ``flags.logs/flags_log.txt``.

    Raises:
        ValueError: if ``flags.model`` is not a supported architecture.
    """
    torch.backends.cudnn.deterministic = flags.deterministic
    print('torch.backends.cudnn.deterministic:', torch.backends.cudnn.deterministic)
    fix_all_seed(flags.seed)

    # Class count follows the dataset choice.
    if flags.dataset == 'cifar10':
        num_classes = 10
    else:
        num_classes = 100

    if flags.model == 'densenet':
        self.network = densenet(num_classes=num_classes)
    elif flags.model == 'wrn':
        self.network = WideResNet(flags.layers, num_classes, flags.widen_factor, flags.droprate)
    elif flags.model == 'allconv':
        self.network = AllConvNet(num_classes)
    elif flags.model == 'resnext':
        self.network = resnext29(num_classes=num_classes)
    else:
        # Fix: raise a specific exception type instead of bare Exception;
        # callers catching Exception still work.
        raise ValueError('Unknown model.')
    self.network = self.network.cuda()
    print(self.network)
    print('flags:', flags)

    if not os.path.exists(flags.logs):
        os.makedirs(flags.logs)
    flags_log = os.path.join(flags.logs, 'flags_log.txt')
    write_log(flags, flags_log)
def build_model(args):
    """Create a WideResNet sized for CIFAR-10/100 and move it to the GPU
    when one is available."""
    num_classes = 10 if args.dataset == 'cifar10' else 100
    model = WideResNet(args.layers, num_classes, args.widen_factor,
                       dropRate=args.droprate)
    if torch.cuda.is_available():
        model.cuda()
        # Let cuDNN autotune conv algorithms for the (fixed) input shapes.
        torch.backends.cudnn.benchmark = True
    return model
def get_model(model, args):
    """Instantiate the network named by ``model``.

    Args:
        model: one of 'alexnet', 'resnet', 'wideresnet', 'densenet'.
        args: parsed CLI arguments (dataset, layers, widen_factor, droprate, gbn).

    Raises:
        ValueError: for an unrecognized model name. Fix: the original fell
        through and silently returned None, deferring the failure to the
        caller's first use of the "model".
    """
    if model == 'alexnet':
        return alexnet()
    if model == 'resnet':
        return resnet(dataset=args.dataset)
    if model == 'wideresnet':
        return WideResNet(args.layers,
                          args.dataset == 'cifar10' and 10 or 100,
                          args.widen_factor,
                          dropRate=args.droprate,
                          gbn=args.gbn)
    if model == 'densenet':
        return densenet()
    raise ValueError('Unknown model: %s' % model)
def create_model(ema=False):
    """Build the network selected by the global ``args.type`` on the GPU.

    With ``ema=True`` the parameters are detached so this copy never
    receives gradients (it is meant to be updated by an EMA optimizer).
    """
    net = Classifier(num_classes=10) if args.type == 1 else WideResNet(num_classes=10)
    net = net.cuda()
    if ema:
        for p in net.parameters():
            p.detach_()
    return net
def get_model(name, num_classes=10, normalize_input=False):
    """Build a network from a dash-separated spec (``wrn-28-10``,
    ``ss-26-32``, ``resnet-110``); raises ValueError otherwise."""
    parts = name.split('-')
    kind = parts[0]
    if kind == 'wrn':
        model = WideResNet(depth=int(parts[1]),
                           num_classes=num_classes,
                           widen_factor=int(parts[2]))
    elif kind == 'ss':
        ss_config = dict(depth=int(parts[1]),
                         base_channels=int(parts[2]),
                         shake_forward=True,
                         shake_backward=True,
                         shake_image=True,
                         input_shape=(1, 3, 32, 32),
                         n_classes=num_classes,
                         )
        model = ShakeNet(ss_config)
    elif kind == 'resnet':
        model = ResNet(num_classes=num_classes, depth=int(parts[1]))
    else:
        raise ValueError('Could not parse model name %s' % name)
    if normalize_input:
        model = Sequential(NormalizeInput(), model)
    return model
class FullySupervisedTrainer:
    """Fully supervised trainer for a WideResNet with an optional EMA shadow.

    Fixes applied: Python-2 ``iterator.next()`` calls replaced with the
    built-in ``next()`` (the originals raise AttributeError on Python 3,
    which this file targets — it uses f-strings), the bare ``except:`` around
    batch fetching narrowed to ``StopIteration`` so real errors surface, and
    ``dataloader.__len__()`` replaced with ``len(dataloader)``.
    """

    def __init__(self, batch_size, model_params, n_steps, optimizer, adam, sgd,
                 steps_validation, steps_checkpoint, dataset, save_path):
        """Set up data, the model/EMA pair, the optimizer and bookkeeping."""
        self.n_steps = n_steps
        self.start_step = 0
        self.steps_validation = steps_validation
        self.steps_checkpoint = steps_checkpoint
        self.num_labeled = 50000
        self.train_loader, _, self.val_loader, self.test_loader, self.lbl_idx, _, self.val_idx = \
            get_dataloaders_with_index(path='../data',
                                       batch_size=batch_size,
                                       num_labeled=self.num_labeled,
                                       which_dataset=dataset,
                                       validation=False)
        print('Labeled samples: ' + str(len(self.train_loader.sampler)))
        # Effective batch size as reported by the loader itself.
        self.batch_size = self.train_loader.batch_size
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(self.device)

        depth, k, n_out = model_params
        self.model = WideResNet(depth=depth, k=k, n_out=n_out, bias=True).to(self.device)
        self.ema_model = WideResNet(depth=depth, k=k, n_out=n_out, bias=True).to(self.device)
        for param in self.ema_model.parameters():
            param.detach_()  # EMA weights never receive gradients

        if optimizer == 'adam':
            self.lr, self.weight_decay = adam
            self.momentum, self.lr_decay = None, None
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
            self.ema_optimizer = WeightEMA(self.model, self.ema_model, self.lr, alpha=0.999)
        else:
            self.lr, self.momentum, self.weight_decay, self.lr_decay = sgd
            self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr,
                                       momentum=self.momentum,
                                       weight_decay=self.weight_decay,
                                       nesterov=True)
            self.ema_optimizer = None

        self.criterion = nn.CrossEntropyLoss()
        self.train_accuracies, self.train_losses = [], []
        self.val_accuracies, self.val_losses = [], []
        self.best_acc = 0
        self.augment = Augment(K=1)
        self.path = save_path
        self.writer = SummaryWriter()

    def train(self):
        """Run the supervised training loop for ``n_steps`` batches."""
        iter_train_loader = iter(self.train_loader)
        for step in range(self.n_steps):
            self.model.train()
            # Get next batch of data; restart the iterator when it is
            # exhausted or when the final, cropped batch comes up.
            try:
                x_input, x_labels, _ = next(iter_train_loader)
                # Check if batch size has been cropped for last batch
                if x_input.shape[0] < self.batch_size:
                    iter_train_loader = iter(self.train_loader)
                    x_input, x_labels, _ = next(iter_train_loader)
            except StopIteration:
                iter_train_loader = iter(self.train_loader)
                x_input, x_labels, _ = next(iter_train_loader)

            # Send to GPU
            x_input = x_input.to(self.device)
            x_labels = x_labels.to(self.device)

            # Augment
            x_input = self.augment(x_input)
            x_input = x_input.reshape((-1, 3, 32, 32))

            # Compute X' predictions and loss
            x_output = self.model(x_input)
            loss = self.criterion(x_output, x_labels)

            # Step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if self.ema_optimizer:
                self.ema_optimizer.step()

            # Decaying learning rate. Used with the SGD Nesterov optimizer.
            if not self.ema_optimizer and step in self.lr_decay:
                for g in self.optimizer.param_groups:
                    g['lr'] *= 0.2

            # Evaluate model
            self.model.eval()
            if step > 0 and not step % self.steps_validation:
                val_acc, is_best = self.evaluate_loss_acc(step)
                if is_best:
                    self.save_model(step=step, path=f'{self.path}/best_checkpoint.pt')

            # Save checkpoint
            if step > 10000 and not step % self.steps_checkpoint:
                self.save_model(step=step, path=f'{self.path}/checkpoint_{step}.pt')

        # --- Training finished ---
        test_val, test_acc = self.evaluate(self.test_loader)
        print("Training done!!\t Test loss: %.3f \t Test accuracy: %.3f" % (test_val, test_acc))
        self.writer.flush()

    # --- support functions ---

    def evaluate_loss_acc(self, step):
        """Evaluate on validation and labeled-train sets; log to TensorBoard.

        Returns (val_acc, is_best) where is_best marks a new best accuracy.
        """
        val_loss, val_acc = self.evaluate(self.val_loader)
        self.val_losses.append(val_loss)
        self.val_accuracies.append(val_acc)
        train_loss, train_acc = self.evaluate(self.train_loader)
        self.train_losses.append(train_loss)
        self.train_accuracies.append(train_acc)

        is_best = False
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            is_best = True

        print("Step %d.\tLoss train_lbl/valid %.2f %.2f\t Accuracy train_lbl/valid %.2f %.2f \tBest acc %.2f \t%s" %
              (step, train_loss, val_loss, train_acc, val_acc, self.best_acc, time.ctime()))
        self.writer.add_scalar("Loss train_label", train_loss, step)
        self.writer.add_scalar("Loss validation", val_loss, step)
        self.writer.add_scalar("Accuracy train_label", train_acc, step)
        self.writer.add_scalar("Accuracy validation", val_acc, step)
        return val_acc, is_best

    def evaluate(self, dataloader):
        """Return (mean loss, accuracy %) over ``dataloader`` without grads.

        Uses the EMA model when an EMA optimizer is active, else the live model.
        """
        correct, total, loss = 0, 0, 0
        with torch.no_grad():
            for i, data in enumerate(dataloader, 0):
                inputs, labels = data[0].to(self.device), data[1].to(self.device)
                if self.ema_optimizer:
                    outputs = self.ema_model(inputs)
                else:
                    outputs = self.model(inputs)
                loss += self.criterion(outputs, labels).item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        loss /= len(dataloader)
        acc = correct / total * 100
        return loss, acc

    def save_model(self, step=None, path='../models/model.pt'):
        """Serialize model/EMA/optimizer state plus training history to ``path``."""
        if not step:
            step = self.n_steps  # Training finished
        torch.save(
            {
                'step': step,
                'model_state_dict': self.model.state_dict(),
                'ema_state_dict': self.ema_model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss_train': self.train_losses,
                'loss_val': self.val_losses,
                'acc_train': self.train_accuracies,
                'acc_val': self.val_accuracies,
                'steps': self.n_steps,
                'batch_size': self.batch_size,
                'num_labels': self.num_labeled,
                'lr': self.lr,
                'weight_decay': self.weight_decay,
                'momentum': self.momentum,
                'lr_decay': self.lr_decay,
                'lbl_idx': self.lbl_idx,
                'val_idx': self.val_idx,
            }, path)

    def load_checkpoint(self, model_name):
        """Restore model/EMA/optimizer state and rebuild the data loaders
        with the checkpoint's saved label/validation index split."""
        saved_model = torch.load(f'../models/{model_name}')
        self.model.load_state_dict(saved_model['model_state_dict'])
        self.ema_model.load_state_dict(saved_model['ema_state_dict'])
        self.optimizer.load_state_dict(saved_model['optimizer_state_dict'])
        self.start_step = saved_model['step']
        self.train_loader, _, self.val_loader, self.test_loader, self.lbl_idx, _, self.val_idx = \
            get_dataloaders_with_index(path='../data',
                                       batch_size=self.batch_size,
                                       num_labeled=self.num_labeled,
                                       which_dataset='cifar10',
                                       lbl_idxs=saved_model['lbl_idx'],
                                       unlbl_idxs=[],
                                       valid_idxs=saved_model['val_idx'])
        print('Model ' + model_name + ' loaded.')
if __name__ == "__main__": from models.wideresnet import WideResNet # from models.WideResNet import WideResNet import torch # tensorboard writer tb_writer = SummaryWriter(log_dir=args.tbsw_logdir) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") K = args.K # steps: 2**20 ideal B = args.B # batch size: 64 ideal mu = args.mu strides = [1, 1, 2, 2] # model = WideResNet(d=28, k=3, n_classes=10, input_features=3, output_features=16, strides=strides) # prop unsup/sup: 7 ideal model = WideResNet(3, 28, 2, 10) # Creating Dataset labels_per_class = [args.labels_per_class for _ in range(10)] dataset_loader = CustomLoader(labels_per_class=labels_per_class, db_dir=args.root, db_name=args.use_database, mode=args.task, download=args.download) dataset_loader.load() unlabeled_dataset = DataSet(dataset_loader.get_set( dataset_loader.unlabeled_set), batch=B * mu, steps=K) labeled_dataset = DataSet(dataset_loader.get_set(
elif args.ood_dataset == 'all': for i in range(len(ood_dataset.datasets)): ood_loader.dataset.datasets[i].imgs = ood_loader.dataset.datasets[ i].imgs[1000:] ood_loader.dataset.cummulative_sizes = ood_loader.dataset.cumsum( ood_loader.dataset.datasets) else: ood_loader.dataset.imgs = ood_loader.dataset.imgs[1000:] ood_loader.dataset.__len__ = len(ood_loader.dataset.imgs) #------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ # Load pre-trained model #------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ if args.model == 'wideresnet': cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10).cuda() elif args.model == 'wideresnet16_8': cnn = WideResNet(depth=16, num_classes=num_classes, widen_factor=8).cuda() elif args.model == 'densenet': cnn = DenseNet3(depth=100, num_classes=num_classes, growth_rate=12, reduction=0.5).cuda() elif args.model == 'vgg13': cnn = VGG(vgg_name='VGG13', num_classes=num_classes).cuda() model_dict = cnn.state_dict() pretrained_dict = torch.load('checkpoints/' + filename + '.pt') cnn.load_state_dict(pretrained_dict) cnn = cnn.cuda()
tensors, labels = [torch.FloatTensor(np.array(x[0]).transpose(2, 0, 1))/255 for x in ims], [x[1] for x in ims] s = torch.stack(tensors) return s, torch.LongTensor(labels) if __name__ == "__main__": from models.wideresnet import WideResNet # from models.WideResNet import WideResNet import torch device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # model = WideResNet(d=28, k=3, n_classes=10, input_features=3, output_features=16, strides=strides) # prop unsup/sup: 7 ideal model = WideResNet(3, 28, 2, 10) # Creating Dataset labels_per_class = [100 for _ in range(10)] dataset_loader = CustomLoader(labels_per_class = labels_per_class, db_dir = args.root, db_name = args.use_database, mode = args.task, download = args.download) dataset_loader.load() test_data = DataSet(dataset_loader.database, batch = 1) test_loader = DataLoader(test_data, collate_fn=default_collate_fn, batch_size=64) model_directory = os.path.expanduser(args.save_dir) model = WideResNet(3, 28, 2, 10) state_dict = torch.load(os.path.join(model_directory, args.name_model_specs))
plot_histograms(correct, confidence) val_acc = np.mean(correct) conf_min = np.min(confidence) conf_max = np.max(confidence) conf_avg = np.mean(confidence) cnn.train() return val_acc, conf_min, conf_max, conf_avg #------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ # Model, Dataset, Optizimer, Scheduler, csv_logger, lambda の設定 #------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ if args.model == 'wideresnet': cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10).cuda() elif args.model == 'wideresnet16_8': cnn = WideResNet(depth=16, num_classes=num_classes, widen_factor=8).cuda() elif args.model == 'densenet': cnn = DenseNet3(depth=100, num_classes=num_classes, growth_rate=12, reduction=0.5).cuda() elif args.model == 'vgg13': cnn = VGG(vgg_name='VGG13', num_classes=num_classes).cuda() prediction_criterion = nn.NLLLoss().cuda() cnn_optimizer = torch.optim.SGD(cnn.parameters(), lr=args.learning_rate, momentum=0.9,
def get_model(config, num_class=10, bn_types=None, data_parallel=True):
    """Instantiate the network named by ``config.model``.

    When ``bn_types`` is given, the multi-BatchNorm variant of the chosen
    architecture is built. With ``data_parallel`` the model is wrapped in
    DataParallel on CUDA; otherwise it is placed on this Horovod rank's GPU.
    Raises NameError for unknown names (kept for caller compatibility).
    """
    name = config.model
    print('model name: {}'.format(name))
    print('bn_types: {}'.format(bn_types))

    single_bn = bn_types is None

    if name == 'resnet50':
        if single_bn:
            model = ResNet(dataset='imagenet', depth=50, num_classes=num_class, bottleneck=True)
        else:
            model = ResNetMultiBN(dataset='imagenet', depth=50, num_classes=num_class,
                                  bn_types=bn_types, bottleneck=True)
    elif name == 'resnet200':
        if single_bn:
            model = ResNet(dataset='imagenet', depth=200, num_classes=num_class, bottleneck=True)
        else:
            model = ResNetMultiBN(dataset='imagenet', depth=200, num_classes=num_class,
                                  bn_types=bn_types, bottleneck=True)
    elif name == 'wresnet40_2':
        if not single_bn:
            raise Exception('unimplemented error')
        model = WideResNet(40, 2, dropout_rate=0.0, num_classes=num_class)
    elif name == 'wresnet28_10':
        if single_bn:
            model = WideResNet(28, 10, dropout_rate=0.0, num_classes=num_class)
        else:
            model = WideResNetMultiBN(28, 10, dropout_rate=0.0, num_classes=num_class,
                                      bn_types=bn_types)
    elif name in ('shakeshake26_2x32d', 'shakeshake26_2x64d',
                  'shakeshake26_2x96d', 'shakeshake26_2x112d'):
        # The width is encoded in the name: shakeshake26_2x<width>d.
        width = int(name.split('x')[-1][:-1])
        if single_bn:
            model = ShakeResNet(26, width, num_class)
        else:
            model = ShakeResNetMultiBN(26, width, num_class, bn_types)
    elif name == 'shakeshake26_2x96d_next':
        if not single_bn:
            raise Exception('unimplemented error')
        model = ShakeResNeXt(26, 96, 4, num_class)
    elif name == 'pyramid':
        if single_bn:
            model = PyramidNet('cifar10', depth=config.pyramidnet_depth,
                               alpha=config.pyramidnet_alpha,
                               num_classes=num_class, bottleneck=True)
        else:
            model = PyramidNetMultiBN('cifar10', depth=config.pyramidnet_depth,
                                      alpha=config.pyramidnet_alpha,
                                      num_classes=num_class, bottleneck=True,
                                      bn_types=bn_types)
    else:
        raise NameError('no model named, %s' % name)

    if data_parallel:
        model = model.cuda()
        model = DataParallel(model)
    else:
        import horovod.torch as hvd
        device = torch.device('cuda', hvd.local_rank())
        model = model.to(device)
    cudnn.benchmark = True
    return model
def __init__(self, batch_size, num_lbls, model_params, n_steps, K, lambda_u,
             optimizer, adam, sgd, steps_validation, steps_checkpoint,
             dataset, save_path, use_pseudo, tau):
    """Initialize the MixMatch trainer.

    Args:
        batch_size: requested loader batch size (effective size re-read from the loader).
        num_lbls: number of labeled training samples.
        model_params: (depth, k, n_out) tuple for the WideResNet.
        n_steps: total number of training steps.
        K: number of augmentations (stored for MixMatch).
        lambda_u: (lambda_u_max, step_top_up) schedule for the unsupervised loss weight.
        optimizer: 'adam' selects Adam; anything else selects SGD with Nesterov momentum.
        adam: (lr, weight_decay) used when optimizer == 'adam'.
        sgd: (lr, momentum, weight_decay, lr_decay) used otherwise.
        steps_validation: validate every this many steps.
        steps_checkpoint: checkpoint every this many steps.
        dataset: dataset identifier passed to the dataloader factory.
        save_path: directory for checkpoints/logs.
        use_pseudo: whether pseudo-labeling of unlabeled data is enabled.
        tau: confidence threshold for pseudo-labels.
    """
    self.validation_set = False
    self.n_steps = n_steps
    self.start_step = 0
    self.K = K
    self.steps_validation = steps_validation
    self.steps_checkpoint = steps_checkpoint
    self.num_labeled = num_lbls
    # Loaders plus the index lists of which samples are labeled/unlabeled/validation.
    self.labeled_loader, self.unlabeled_loader, self.val_loader, self.test_loader, self.lbl_idx, self.unlbl_idx, self.val_idx \
        = get_dataloaders_with_index(path='../data',
                                     batch_size=batch_size,
                                     num_labeled=num_lbls,
                                     which_dataset=dataset,
                                     validation=self.validation_set)
    print('Labeled samples: ' + str(len(self.labeled_loader.sampler)) +
          '\tUnlabeled samples: ' + str(len(self.unlabeled_loader.sampler)))
    self.targets_list = np.array(self.labeled_loader.dataset.targets)
    self.batch_size = self.labeled_loader.batch_size
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(self.device)

    # -- Model --
    depth, k, n_out = model_params
    self.model = WideResNet(depth=depth, k=k, n_out=n_out, bias=False).to(self.device)
    # EMA copy: detached so it never receives gradients; updated by WeightEMA.
    self.ema_model = WideResNet(depth=depth, k=k, n_out=n_out, bias=False).to(self.device)
    for param in self.ema_model.parameters():
        param.detach_()

    if optimizer == 'adam':
        self.lr, self.weight_decay = adam
        self.momentum, self.lr_decay = None, None
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        # NOTE(review): self.lr is passed as WeightEMA's third positional
        # argument — confirm it matches that constructor's signature.
        self.ema_optimizer = WeightEMA(self.model, self.ema_model, self.lr, alpha=0.999)
    else:
        self.lr, self.momentum, self.weight_decay, self.lr_decay = sgd
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr,
                                   momentum=self.momentum,
                                   weight_decay=self.weight_decay,
                                   nesterov=True)
        self.ema_optimizer = None

    # MixMatch loss: supervised CE plus a ramped-up unsupervised term.
    self.lambda_u_max, self.step_top_up = lambda_u
    self.loss_mixmatch = Loss(self.lambda_u_max, self.step_top_up)
    self.criterion = nn.CrossEntropyLoss()

    self.train_accuracies, self.train_losses, = [], []
    self.val_accuracies, self.val_losses, = [], []
    self.best_acc = 0
    self.mixmatch = MixMatch(self.model, self.batch_size, self.device)
    self.writer = SummaryWriter()
    self.path = save_path

    # -- Pseudo label --
    self.use_pseudo = use_pseudo
    self.steps_pseudo_lbl = 5000  # presumably the step interval between pseudo-labeling rounds — confirm
    self.tau = tau  # confidence threshold
    self.min_unlbl_samples = 1000  # presumably the minimum unlabeled pool size — confirm
    # Make a deep copy of original unlabeled loader
    _, self.unlabeled_loader_original, _, _, _, _, _ \
        = get_dataloaders_with_index(path='../data',
                                     batch_size=batch_size,
                                     num_labeled=num_lbls,
                                     which_dataset=dataset,
                                     lbl_idxs=self.lbl_idx,
                                     unlbl_idxs=self.unlbl_idx,
                                     valid_idxs=self.val_idx,
                                     validation=self.validation_set)
def main():
    """Train a WideResNet on CIFAR-10/100 with optional geometric augmentation.

    Parses CLI args, builds the (optionally augmenting) data pipeline,
    optionally resumes from a checkpoint, then runs the epoch loop and saves
    the best model by validation precision@1.

    Fix: the final ``print 'Best accuracy: ', best_prec1`` was a Python-2
    print statement — a SyntaxError on Python 3 — now the function form.
    """
    global args, best_prec1
    args = parser.parse_args()
    # torch.cuda.set_device(args.gpu)
    if args.tensorboard:
        print("Using tensorboard")
        configure("exp/%s" % (args.name))

    # Data loading code
    if args.augment:
        print(
            "Doing image augmentation with\n"
            "Zoom: prob: {zoom_prob} range: {zoom_range}\n"
            "Stretch: prob: {stretch_prob} range: {stretch_range}\n"
            "Rotation: prob: {rotation_prob} range: {rotation_degree}".format(
                zoom_prob=args.zoom_prob,
                zoom_range=args.zoom_range,
                stretch_prob=args.stretch_prob,
                stretch_range=args.stretch_range,
                rotation_prob=args.rotation_prob,
                rotation_degree=args.rotation_degree))
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            # Replicate-pad 4px on every side before the random 32x32 crop.
            # NOTE(review): Variable(..., volatile=True) is deprecated since
            # PyTorch 0.4 — consider torch.no_grad(); kept for compatibility.
            transforms.Lambda(lambda x: F.pad(
                Variable(x.unsqueeze(0), requires_grad=False, volatile=True),
                (4, 4, 4, 4), mode='replicate').data.squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(prob=args.rotation_prob, degree=args.rotation_degree),
            transforms.RandomZoom(prob=args.zoom_prob, zoom_range=args.zoom_range),
            transforms.RandomStretch(prob=args.stretch_prob, stretch_range=args.stretch_range),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    kwargs = {'num_workers': 1, 'pin_memory': True}
    assert (args.dataset == 'cifar10' or args.dataset == 'cifar100')
    train_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data', train=True, download=True,
                                                transform=transform_train),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data', train=False,
                                                transform=transform_test),
        batch_size=args.batch_size, shuffle=True, **kwargs)

    # create model
    model = WideResNet(args.layers, args.dataset == 'cifar10' and 10 or 100,
                       args.widen_factor, dropRate=args.droprate)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch + 1)
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch)
        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)
    print('Best accuracy: ', best_prec1)
'AllConv_PGP': lambda args: FuseTrainWrapper( AllConv_PGP(args.nclasses)), 'PreResNet164': lambda args: L.Classifier( PreResNet(164, args.nclasses)), 'PreResNet164_DConv': lambda args: L.Classifier( PreResNet_DConv(164, args.nclasses)), 'PreResNet164_PGP': lambda args: FuseTrainWrapper( PreResNet_PGP(164, args.nclasses)), 'DenseNetBC100': lambda args: L.Classifier( DenseNetBC(args.nclasses, (16, 16, 16), 12)), 'DenseNetBC100_DConv': lambda args: L.Classifier( DenseNetBC_DConv(args.nclasses, (16, 16, 16), 12)), 'DenseNetBC100_PGP': lambda args: FuseTrainWrapper( DenseNetBC_PGP(args.nclasses, (16, 16, 16), 12)), 'WideResNet28-10': lambda args: L.Classifier( WideResNet(28, args.nclasses, 10)), 'WideResNet28-10_DConv': lambda args: L.Classifier( WideResNet_DConv(28, args.nclasses, 10)), 'WideResNet28-10_PGP': lambda args: FuseTrainWrapper( WideResNet_PGP(28, args.nclasses, 10)), 'ResNeXt29_8x64d': lambda args: L.Classifier( ResNeXt(29, args.nclasses)), 'ResNeXt29_8x64d_DConv': lambda args: L.Classifier( ResNeXt_DConv(29, args.nclasses)), 'ResNeXt29_8x64d_PGP': lambda args: FuseTrainWrapper( ResNeXt_PGP(29, args.nclasses)), 'PyramidNetB164': lambda args: L.Classifier( PyramidNet(164, args.nclasses)), 'PyramidNetB164_DConv': lambda args: L.Classifier( PyramidNet_DConv(164, args.nclasses)), 'PyramidNetB164_PGP': lambda args: FuseTrainWrapper(
growth_rate=12, n_class=args.nclasses, in_ch=24, block=3, bottleneck=True, reduction=0.5)), 'DenseNetBC100': lambda args: L.Classifier(DenseNetBC(args.nclasses, (16, 16, 16), 12)), 'DenseNetBC100_DConv': lambda args: L.Classifier(DenseNetBC_DConv(args.nclasses, (16, 16, 16), 12)), 'DenseNetBC100_PGP': lambda args: FuseTrainWrapper( DenseNetBC_PGP(args.nclasses, (16, 16, 16), 12)), 'WideResNet28-10': lambda args: L.Classifier(WideResNet(28, args.nclasses, 10)), 'WideResNet28-10_DConv': lambda args: L.Classifier(WideResNet_DConv(28, args.nclasses, 10)), 'WideResNet28-10_PGP': lambda args: FuseTrainWrapper(WideResNet_PGP(28, args.nclasses, 10)), 'ResNeXt29_8x64d': lambda args: L.Classifier(ResNeXt(29, args.nclasses)), 'ResNeXt29_8x64d_DConv': lambda args: L.Classifier(ResNeXt_DConv(29, args.nclasses)), 'ResNeXt29_8x64d_PGP': lambda args: FuseTrainWrapper(ResNeXt_PGP(29, args.nclasses)), 'PyramidNetB164': lambda args: L.Classifier(PyramidNet(164, args.nclasses)), 'PyramidNetB164_DConv': lambda args: L.Classifier(PyramidNet_DConv(164, args.nclasses)), 'PyramidNetB164_PGP':
class MixMatchTrainer:
    """Trainer for MixMatch semi-supervised learning.

    Holds labeled/unlabeled/validation/test dataloaders, a WideResNet model
    plus an EMA copy, and runs a step-based training loop.  Optionally, after
    ``steps_pseudo_lbl``-step intervals (once past step 50000), it converts
    confidently-predicted unlabeled samples into pseudo-labeled training
    samples and rebuilds the dataloaders.
    """

    def __init__(self, batch_size, num_lbls, model_params, n_steps, K, lambda_u, optimizer, adam, sgd,
                 steps_validation, steps_checkpoint, dataset, save_path, use_pseudo, tau):
        # Basic loop bookkeeping.
        self.validation_set = False
        self.n_steps = n_steps
        self.start_step = 0  # overwritten by load_checkpoint() on resume
        self.K = K  # number of augmentations per unlabeled sample (MixMatch K)
        self.steps_validation = steps_validation
        self.steps_checkpoint = steps_checkpoint
        self.num_labeled = num_lbls
        self.labeled_loader, self.unlabeled_loader, self.val_loader, self.test_loader, self.lbl_idx, self.unlbl_idx, self.val_idx \
            = get_dataloaders_with_index(path='../data', batch_size=batch_size, num_labeled=num_lbls,
                                         which_dataset=dataset, validation=self.validation_set)
        print('Labeled samples: ' + str(len(self.labeled_loader.sampler)) + '\tUnlabeled samples: ' +
              str(len(self.unlabeled_loader.sampler)))
        # Ground-truth labels of the whole training set; used later to score
        # pseudo-label accuracy in get_pseudo_labels().
        self.targets_list = np.array(self.labeled_loader.dataset.targets)
        self.batch_size = self.labeled_loader.batch_size
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(self.device)

        # -- Model --
        depth, k, n_out = model_params
        self.model = WideResNet(depth=depth, k=k, n_out=n_out, bias=False).to(self.device)
        self.ema_model = WideResNet(depth=depth, k=k, n_out=n_out, bias=False).to(self.device)
        # EMA weights are updated manually by WeightEMA, never by autograd.
        for param in self.ema_model.parameters():
            param.detach_()

        if optimizer == 'adam':
            # Adam path uses the EMA model for evaluation (see evaluate()).
            self.lr, self.weight_decay = adam
            self.momentum, self.lr_decay = None, None
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
            # NOTE(review): self.lr is passed positionally here; confirm the
            # WeightEMA signature expects it as the third argument.
            self.ema_optimizer = WeightEMA(self.model, self.ema_model, self.lr, alpha=0.999)
        else:
            # SGD path uses step-based lr decay instead of EMA.
            self.lr, self.momentum, self.weight_decay, self.lr_decay = sgd
            self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=self.momentum,
                                       weight_decay=self.weight_decay, nesterov=True)
            self.ema_optimizer = None

        # Unsupervised-loss ramp-up schedule and loss functions.
        self.lambda_u_max, self.step_top_up = lambda_u
        self.loss_mixmatch = Loss(self.lambda_u_max, self.step_top_up)
        self.criterion = nn.CrossEntropyLoss()

        # Metric history (appended every validation pass).
        self.train_accuracies, self.train_losses, = [], []
        self.val_accuracies, self.val_losses, = [], []
        self.best_acc = 0

        self.mixmatch = MixMatch(self.model, self.batch_size, self.device)
        self.writer = SummaryWriter()
        self.path = save_path

        # -- Pseudo label --
        self.use_pseudo = use_pseudo
        self.steps_pseudo_lbl = 5000  # interval (in steps) between pseudo-labeling passes
        self.tau = tau  # confidence threshold; -1 selects the class-balanced strategy
        self.min_unlbl_samples = 1000
        # Make a deep copy of original unlabeled loader: pseudo-labeling keeps
        # predicting over the ORIGINAL unlabeled pool even after
        # labeled/unlabeled splits are rebuilt.
        _, self.unlabeled_loader_original, _, _, _, _, _ \
            = get_dataloaders_with_index(path='../data', batch_size=batch_size, num_labeled=num_lbls,
                                         which_dataset=dataset, lbl_idxs=self.lbl_idx, unlbl_idxs=self.unlbl_idx,
                                         valid_idxs=self.val_idx, validation=self.validation_set)

    def train(self):
        """Run the step-based MixMatch training loop.

        Each step draws one labeled and one unlabeled batch (restarting the
        iterators when exhausted or when a short last batch appears), applies
        the MixMatch transform, optimizes the combined loss, and periodically
        validates, checkpoints, and (optionally) pseudo-labels.
        """
        iter_labeled_loader = iter(self.labeled_loader)
        iter_unlabeled_loader = iter(self.unlabeled_loader)

        for step in range(self.start_step, self.n_steps):
            # Get next batch of data
            self.model.train()
            # NOTE(review): `.next()` is the Python-2-style iterator method;
            # on current PyTorch/py3 loaders this raises AttributeError, which
            # the bare `except` silently converts into an iterator restart —
            # confirm the intended loader version.
            try:
                x_imgs, x_labels, _ = iter_labeled_loader.next()
                # Check if batch size has been cropped for last batch
                if x_imgs.shape[0] < self.batch_size:
                    iter_labeled_loader = iter(self.labeled_loader)
                    x_imgs, x_labels, _ = iter_labeled_loader.next()
            except:
                iter_labeled_loader = iter(self.labeled_loader)
                x_imgs, x_labels, _ = iter_labeled_loader.next()

            try:
                u_imgs, _, _ = iter_unlabeled_loader.next()
                if u_imgs.shape[0] < self.batch_size:
                    iter_unlabeled_loader = iter(self.unlabeled_loader)
                    u_imgs, _, _ = iter_unlabeled_loader.next()
            except:
                iter_unlabeled_loader = iter(self.unlabeled_loader)
                u_imgs, _, _ = iter_unlabeled_loader.next()

            # Send to GPU
            x_imgs = x_imgs.to(self.device)
            x_labels = x_labels.to(self.device)
            u_imgs = u_imgs.to(self.device)

            # MixMatch algorithm: returns mixed labeled batch X' and mixed
            # unlabeled batch U' with guessed (sharpened) targets.
            x, u = self.mixmatch.run(x_imgs, x_labels, u_imgs)
            x_input, x_targets = x
            u_input, u_targets = u
            u_targets.detach_()  # stop gradients from propagating to label guessing

            # Compute X' predictions
            x_output = self.model(x_input)

            # Compute U' predictions, separated in K batches so each forward
            # pass sees batch_size samples.
            u_batch_outs = []
            for k in range(self.K):
                u_batch = u_input[k * self.batch_size:(k + 1) * self.batch_size]
                u_batch_outs.append(self.model(u_batch))
            u_outputs = torch.cat(u_batch_outs, dim=0)

            # Compute loss (supervised + ramped unsupervised term)
            loss = self.loss_mixmatch(x_output, x_targets, u_outputs, u_targets, step)

            # Step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if self.ema_optimizer:
                self.ema_optimizer.step()

            # Decaying learning rate; used with the SGD/Nesterov optimizer
            # (lr_decay holds the step indices at which to decay).
            if not self.ema_optimizer and step in self.lr_decay:
                for g in self.optimizer.param_groups:
                    g['lr'] *= 0.2

            # Evaluate model
            self.model.eval()
            if step > 0 and not step % self.steps_validation:
                val_acc, is_best = self.evaluate_loss_acc(step)
                # Only persist "best" checkpoints after warm-up (10k steps).
                if is_best and step > 10000:
                    self.save_model(step=step, path=f'{self.path}/best_checkpoint.pt')

            # Save checkpoint
            if step > 0 and not step % self.steps_checkpoint:
                self.save_model(step=step, path=f'{self.path}/checkpoint_{step}.pt')

            # Generate Pseudo-labels
            if self.use_pseudo and step >= 50000 and not step % self.steps_pseudo_lbl:
                # matrix columns: [index, confidence, pseudo_label, true_label, is_ground_truth]
                matrix = self.get_pseudo_labels()
                self.print_threshold_comparison(matrix)

                # Generate pseudo set based on threshold (same for all classes)
                if self.tau != -1:
                    matrix = self.generate_pseudo_set(matrix)
                # Generate pseudo set balanced (top 10% guesses of each class)
                else:
                    matrix = self.generate_pseudo_set_balanced(matrix)

                # Loaders were rebuilt; restart both iterators.
                iter_labeled_loader = iter(self.labeled_loader)
                iter_unlabeled_loader = iter(self.unlabeled_loader)

                # Save the selection matrix for offline analysis.
                torch.save(matrix, f'{self.path}/pseudo_matrix_balanced_{step}.pt')

        # --- Training finished ---
        test_val, test_acc = self.evaluate(self.test_loader)
        print("Training done!!\t Test loss: %.3f \t Test accuracy: %.3f" % (test_val, test_acc))
        self.writer.flush()

    # --- support functions ---

    def evaluate_loss_acc(self, step):
        """Evaluate on validation and labeled-train sets, record history and
        TensorBoard scalars, and return ``(val_acc, is_best)``."""
        val_loss, val_acc = self.evaluate(self.val_loader)
        self.val_losses.append(val_loss)
        self.val_accuracies.append(val_acc)
        train_loss, train_acc = self.evaluate(self.labeled_loader)
        self.train_losses.append(train_loss)
        self.train_accuracies.append(train_acc)

        is_best = False
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            is_best = True

        print("Step %d.\tLoss train_lbl/valid %.2f %.2f\t Accuracy train_lbl/valid %.2f %.2f \tBest acc %.2f \t%s" %
              (step, train_loss, val_loss, train_acc, val_acc, self.best_acc, time.ctime()))

        self.writer.add_scalar("Loss train_label", train_loss, step)
        self.writer.add_scalar("Loss validation", val_loss, step)
        self.writer.add_scalar("Accuracy train_label", train_acc, step)
        self.writer.add_scalar("Accuracy validation", val_acc, step)
        return val_acc, is_best

    def evaluate(self, dataloader):
        """Return ``(mean_loss, accuracy_percent)`` over ``dataloader``.

        Uses the EMA model when an EMA optimizer is active (Adam path),
        otherwise the raw model.  Caller is expected to have set eval mode.
        """
        correct, total, loss = 0, 0, 0
        with torch.no_grad():
            for i, data in enumerate(dataloader, 0):
                inputs, labels = data[0].to(self.device), data[1].to(self.device)
                if self.ema_optimizer:
                    outputs = self.ema_model(inputs)
                else:
                    outputs = self.model(inputs)
                loss += self.criterion(outputs, labels).item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        # Mean of per-batch losses (len(dataloader) == number of batches).
        loss /= dataloader.__len__()
        acc = correct / total * 100
        return loss, acc

    def get_losses(self):
        """Return the per-batch loss histories tracked by the MixMatch loss:
        (total, supervised lx, unsupervised lu, weighted lu)."""
        return self.loss_mixmatch.loss_list, self.loss_mixmatch.lx_list, self.loss_mixmatch.lu_list, self.loss_mixmatch.lu_weighted_list

    def save_model(self, step=None, path=f'../model.pt'):
        """Serialize models, optimizer, metric history and split indices.

        NOTE(review): ``if not step`` treats an explicit ``step=0`` the same
        as "unset" and substitutes ``n_steps`` — confirm step 0 saves are
        never intended.
        """
        loss_list, lx, lu, lu_weighted = self.get_losses()
        if not step:
            step = self.n_steps  # Training finished
        torch.save({
            'step': step,
            'model_state_dict': self.model.state_dict(),
            'ema_state_dict': self.ema_model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss_train': self.train_losses,
            'loss_val': self.val_losses,
            'acc_train': self.train_accuracies,
            'acc_val': self.val_accuracies,
            'loss_batch': loss_list,
            'lx': lx,
            'lu': lu,
            'lu_weighted': lu_weighted,
            'steps': self.n_steps,
            'batch_size': self.batch_size,
            'num_labels': self.num_labeled,
            'lambda_u_max': self.lambda_u_max,
            'step_top_up': self.step_top_up,
            'lr': self.lr,
            'weight_decay': self.weight_decay,
            'momentum': self.momentum,
            'lr_decay': self.lr_decay,
            'lbl_idx': self.lbl_idx,
            'unlbl_idx': self.unlbl_idx,
            'val_idx': self.val_idx,
        }, path)

    def load_checkpoint(self, path_checkpoint):
        """Restore models, optimizer, metric history and dataloader splits
        from a checkpoint written by save_model(); resumes at the saved step.

        NOTE(review): hard-codes which_dataset='cifar10' regardless of the
        dataset used at construction time — confirm this is intentional.
        """
        saved_model = torch.load(path_checkpoint)
        self.model.load_state_dict(saved_model['model_state_dict'])
        self.ema_model.load_state_dict(saved_model['ema_state_dict'])
        self.optimizer.load_state_dict(saved_model['optimizer_state_dict'])
        self.start_step = saved_model['step']
        self.train_losses = saved_model['loss_train']
        self.val_losses = saved_model['loss_val']
        self.train_accuracies = saved_model['acc_train']
        self.val_accuracies = saved_model['acc_val']
        self.batch_size = saved_model['batch_size']
        self.num_labeled = saved_model['num_labels']
        self.labeled_loader, self.unlabeled_loader, self.val_loader, self.test_loader, self.lbl_idx, self.unlbl_idx, self.val_idx = \
            get_dataloaders_with_index(path='../data', batch_size=self.batch_size, num_labeled=self.num_labeled,
                                       which_dataset='cifar10', lbl_idxs=saved_model['lbl_idx'],
                                       unlbl_idxs=saved_model['unlbl_idx'], valid_idxs=saved_model['val_idx'],
                                       validation=self.validation_set)
        self.unlabeled_loader_original = self.unlabeled_loader
        print('Model ' + path_checkpoint + ' loaded.')

    def get_pseudo_labels(self):
        """Predict over the original unlabeled pool and build the selection
        matrix with columns
        ``[index, confidence, pseudo_label, true_label, is_ground_truth]``.
        """
        matrix = torch.tensor([], device=self.device)
        # Iterate through unlabeled loader
        for batch_idx, (data, target, idx) in enumerate(self.unlabeled_loader_original):
            with torch.no_grad():
                # Get predictions for unlabeled samples
                out = self.model(data.to(self.device))
                p_out = torch.softmax(out, dim=1)  # turn logits into a probability distribution
                confidence, pseudo_lbl = torch.max(p_out, dim=1)
                pseudo_lbl_batch = torch.vstack((idx.to(self.device), confidence, pseudo_lbl)).T
            # Append to matrix
            matrix = torch.cat((matrix, pseudo_lbl_batch), dim=0)  # (n_unlabeled, 3)

        n_unlabeled = matrix.shape[0]
        indices = matrix[:, 0].cpu().numpy().astype(int)
        ground_truth = self.targets_list[indices]
        matrix = torch.vstack((matrix.T, torch.tensor(ground_truth, device=self.device))).T  # (n_unlabeled, 4)
        matrix = torch.vstack((matrix.T, torch.zeros(n_unlabeled, device=self.device))).T  # (n_unlabeled, 5)
        # matrix columns: [index, confidence, pseudo_label, true_label, is_ground_truth]
        # Check if pseudo label is ground truth
        for i in range(n_unlabeled):
            if matrix[i, 2] == matrix[i, 3]:
                matrix[i, 4] = 1
        return matrix

    def generate_pseudo_set(self, matrix):
        """Promote samples with confidence >= self.tau to the labeled set
        (writing their pseudo-labels into the dataset targets), rebuild the
        dataloaders, and return the selected rows of ``matrix``."""
        unlbl_mask1 = (matrix[:, 1] < self.tau)
        # unlbl_mask2 = (matrix[:, 1] >= 0.99)
        pseudo_mask = (matrix[:, 1] >= self.tau)
        # pseudo_mask = (matrix[:, 1] >= self.tau) & (matrix[:, 1] < 0.99)
        # unlbl_indices = torch.cat((matrix[unlbl_mask1, 0], matrix[unlbl_mask2, 0]))
        unlbl_indices = matrix[unlbl_mask1, 0]
        matrix = matrix[pseudo_mask, :]
        indices = matrix[:, 0]

        # New labeled split = original labeled indices + pseudo-labeled ones.
        new_lbl_idx = np.int_(torch.cat((torch.tensor(self.lbl_idx, device=self.device), indices)).tolist())
        new_unlbl_idx = np.int_(unlbl_indices.tolist())

        self.labeled_loader, self.unlabeled_loader, self.val_loader, self.test_loader, _, _, new_val_idx = \
            get_dataloaders_with_index(path='../data', batch_size=self.batch_size, num_labeled=self.num_labeled,
                                       which_dataset='cifar10', lbl_idxs=new_lbl_idx, unlbl_idxs=new_unlbl_idx,
                                       valid_idxs=self.val_idx, validation=self.validation_set)
        # Sanity checks: validation split unchanged, no samples lost.
        # NOTE(review): the 50000 total is CIFAR-specific.
        assert np.allclose(self.val_idx, new_val_idx), 'error'
        assert (len(self.labeled_loader.sampler) + len(self.unlabeled_loader.sampler) == 50000), 'error'

        # Change real labels for pseudo labels (mutates the shared dataset).
        for i in range(matrix.shape[0]):
            index = int(matrix[i, 0].item())
            assert int(matrix[i, 3]) == self.labeled_loader.dataset.targets[index]
            pseudo_labels = int(matrix[i, 2].item())
            self.labeled_loader.dataset.targets[index] = pseudo_labels

        correct = torch.sum(matrix[:, 4]).item()
        pseudo_acc = correct / matrix.shape[0] * 100 if matrix.shape[0] > 0 else 0
        print('Generated labels: %d\t Correct: %d\t Accuracy: %.2f' % (matrix.shape[0], correct, pseudo_acc))
        print('Training with Labeled / Unlabeled / Validation samples\t %d %d %d' % (len(new_lbl_idx),
                                                                                     len(new_unlbl_idx),
                                                                                     len(self.val_idx)))
        return matrix

    def generate_pseudo_set_balanced(self, matrix_all):
        """Class-balanced variant of generate_pseudo_set: for each of the 10
        classes, promote the top-10%-confidence guesses (per-class 90th
        percentile threshold) instead of using a single global threshold."""
        unlbl_indices = torch.tensor([], device=self.device)
        # Get top 10% confident guesses for each class
        matrix = torch.tensor([], device=self.device)
        for i in range(10):
            matrix_label = matrix_all[matrix_all[:, 2] == i, :]
            threshold = torch.quantile(matrix_label[:, 1], 0.9)  # confidence at the 90th percentile
            unlbl_idxs = matrix_label[matrix_label[:, 1] < threshold, 0]
            matrix_label = matrix_label[matrix_label[:, 1] >= threshold, :]
            matrix = torch.cat((matrix, matrix_label), dim=0)
            unlbl_indices = torch.cat((unlbl_indices, unlbl_idxs))

        indices = matrix[:, 0]
        new_lbl_idx = np.int_(torch.cat((torch.tensor(self.lbl_idx, device=self.device), indices)).tolist())
        new_unlbl_idx = np.int_(unlbl_indices.tolist())

        self.labeled_loader, self.unlabeled_loader, self.val_loader, self.test_loader, _, _, new_val_idx = \
            get_dataloaders_with_index(path='../data', batch_size=self.batch_size, num_labeled=self.num_labeled,
                                       which_dataset='cifar10', lbl_idxs=new_lbl_idx, unlbl_idxs=new_unlbl_idx,
                                       valid_idxs=self.val_idx, validation=self.validation_set)
        # Sanity checks: validation split unchanged, no samples lost.
        assert np.allclose(self.val_idx, new_val_idx), 'error'
        assert (len(self.labeled_loader.sampler) + len(self.unlabeled_loader.sampler) == 50000), 'error'

        # Change real labels for pseudo labels (mutates the shared dataset).
        for i in range(matrix.shape[0]):
            index = int(matrix[i, 0].item())
            assert int(matrix[i, 3]) == self.labeled_loader.dataset.targets[index]
            pseudo_labels = int(matrix[i, 2].item())
            self.labeled_loader.dataset.targets[index] = pseudo_labels

        correct = torch.sum(matrix[:, 4]).item()
        pseudo_acc = correct / matrix.shape[0] * 100 if matrix.shape[0] > 0 else 0
        print('Generated labels: %d\t Correct: %d\t Accuracy: %.2f' % (matrix.shape[0], correct, pseudo_acc))
        print('Training with Labeled / Unlabeled / Validation samples\t %d %d %d' % (len(new_lbl_idx),
                                                                                     len(new_unlbl_idx),
                                                                                     len(self.val_idx)))
        return matrix

    def print_threshold_comparison(self, matrix):
        """Print generated/correct/precision counts for several candidate
        confidence thresholds (diagnostic only; does not modify state)."""
        m2 = matrix[matrix[:, 1] >= 0.9, :]
        for i, tau in enumerate([0.9, 0.95, 0.97, 0.99, 0.999]):
            pseudo_labels = m2[m2[:, 1] >= tau, :]
            total = pseudo_labels.shape[0]
            correct = torch.sum(pseudo_labels[:, 4]).item()
            # eps guards against division by zero when no labels pass tau.
            print('Confidence threshold %.3f\t Generated / Correct / Precision\t %d\t%d\t%.2f ' %
                  (tau, total, correct, correct / (total + np.finfo(float).eps) * 100))
images = Variable(images, volatile=True) outputs = model(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum() print('Accuracy of the network on the %d test images: %d %%' % (total, 100 * correct / total)) if __name__ == '__main__': parser = argparse.ArgumentParser(description='PyTorch FeedForward Example') parser.add_argument('--epochs', type=int, default=3, help='number of epochs to train') parser.add_argument('--lr', type=float, default=0.01, help='learning rate') parser.add_argument('--batch_size', type=int, default=128, help='batch size') args = parser.parse_args() train_loader, test_loader = data_loader() model = WideResNet(int(28), 10) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) train(args.epochs) test()
def main():
    """Entry point for semi-supervised fine-tuning.

    Parses CLI arguments, prepares the experiment directory tree, builds the
    backbone network (DataParallel across ``--gpus``) plus a linear
    classifier, optionally loads pretrained weights, and runs the epoch loop
    with periodic validation, checkpointing and TensorBoard logging.

    Side effects: sets the module-level globals ``args`` and ``writer``;
    creates directories and writes checkpoints/logs under ``args.exp``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--stage', default='train', type=str)
    parser.add_argument('--dataset', default='imagenet', type=str)
    parser.add_argument('--lr', default=0.0012, type=float)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--gpus', default='0,1,2,3', type=str)
    parser.add_argument('--weight_decay', default=1e-5, type=float)
    parser.add_argument('--max_epoch', default=30, type=int)
    parser.add_argument('--lr_decay_steps', default='15,20,25', type=str)
    parser.add_argument('--exp', default='', type=str)
    parser.add_argument('--list', default='', type=str)
    parser.add_argument('--resume_path', default='', type=str)
    parser.add_argument('--pretrain_path', default='', type=str)
    parser.add_argument('--n_workers', default=32, type=int)
    parser.add_argument('--network', default='resnet50', type=str)
    # args is read by helper functions (train/validate) as a global.
    global args
    args = parser.parse_args()

    # Create the experiment directory layout if missing.
    if not os.path.exists(args.exp):
        os.makedirs(args.exp)
    if not os.path.exists(os.path.join(args.exp, 'runs')):
        os.makedirs(os.path.join(args.exp, 'runs'))
    if not os.path.exists(os.path.join(args.exp, 'models')):
        os.makedirs(os.path.join(args.exp, 'models'))
    if not os.path.exists(os.path.join(args.exp, 'logs')):
        os.makedirs(os.path.join(args.exp, 'logs'))

    # logger initialize
    # NOTE(review): `logger` is never used below; logging goes through the
    # module-level `logging` calls instead — confirm getLogger() is needed
    # only for its side effects (handler setup).
    logger = getLogger(args.exp)

    device_ids = list(map(lambda x: int(x), args.gpus.split(',')))
    # NOTE(review): the device string contains a space ('cuda: 0'); the
    # canonical form is 'cuda:0' — verify this parses on the torch version
    # in use.
    device = torch.device('cuda: 0')

    train_loader, val_loader = cifar.get_semi_dataloader(
        args) if args.dataset.startswith(
            'cifar') else imagenet.get_semi_dataloader(args)

    # create model
    if args.network == 'alexnet':
        network = AlexNet(128)
    elif args.network == 'alexnet_cifar':
        network = AlexNet_cifar(128)
    elif args.network == 'resnet18_cifar':
        network = ResNet18_cifar()
    elif args.network == 'resnet50_cifar':
        network = ResNet50_cifar()
    elif args.network == 'wide_resnet28':
        # `and/or` chain picks 10 classes for cifar10, else 100.
        network = WideResNet(28, args.dataset == 'cifar10' and 10 or 100, 2)
    elif args.network == 'resnet18':
        network = resnet18()
    elif args.network == 'resnet50':
        network = resnet50()
    network = nn.DataParallel(network, device_ids=device_ids)
    network.to(device)
    # NOTE(review): 2048-dim input / 1000 classes is resnet50+ImageNet
    # specific — confirm for the cifar/wide_resnet backbones.
    classifier = nn.Linear(2048, 1000).to(device)

    # create optimizer: backbone at args.lr, classifier head at 50x lr.
    parameters = network.parameters()
    optimizer = torch.optim.SGD(
        parameters,
        lr=args.lr,
        momentum=0.9,
        weight_decay=args.weight_decay,
    )
    cls_optimizer = torch.optim.SGD(
        classifier.parameters(),
        lr=args.lr * 50,
        momentum=0.9,
        weight_decay=args.weight_decay,
    )

    cudnn.benchmark = True

    # create memory_bank
    global writer
    writer = SummaryWriter(comment='SemiSupervised',
                           logdir=os.path.join(args.exp, 'runs'))

    # create criterion
    criterion = nn.CrossEntropyLoss()

    logging.info(beautify(args))
    # NOTE(review): args.resume_path is parsed but never consumed here;
    # start_epoch is always 0 — resuming appears unimplemented.
    start_epoch = 0
    if args.pretrain_path != '' and args.pretrain_path != 'none':
        logging.info('loading pretrained file from {}'.format(
            args.pretrain_path))
        checkpoint = torch.load(args.pretrain_path)
        state_dict = checkpoint['state_dict']
        # Keep only keys that exist in this network, skipping fc layers.
        valid_state_dict = {
            k: v
            for k, v in state_dict.items()
            if k in network.state_dict() and 'fc.' not in k
        }
        # Missing keys fall back to this network's freshly-initialized values.
        for k, v in network.state_dict().items():
            if k not in valid_state_dict:
                logging.info('{}: Random Init'.format(k))
                valid_state_dict[k] = v
        # logging.info(valid_state_dict.keys())
        network.load_state_dict(valid_state_dict)
    else:
        logging.info('Training SemiSupervised Learning From Scratch')

    logging.info('start training')
    best_acc = 0.0
    try:
        for i_epoch in range(start_epoch, args.max_epoch):
            train(i_epoch, network, classifier, criterion, optimizer,
                  cls_optimizer, train_loader, device)

            # Rolling checkpoint written every epoch.
            checkpoint = {
                'epoch': i_epoch + 1,
                'state_dict': network.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(checkpoint,
                       os.path.join(args.exp, 'models', 'checkpoint.pth'))

            adjust_learning_rate(args.lr_decay_steps, optimizer, i_epoch)
            # Validate every other epoch; track and persist the best model.
            if i_epoch % 2 == 0:
                acc1, acc5 = validate(i_epoch, network, classifier, val_loader,
                                      device)
                if acc1 >= best_acc:
                    best_acc = acc1
                    torch.save(checkpoint,
                               os.path.join(args.exp, 'models', 'best.pth'))
                writer.add_scalar('acc1', acc1, i_epoch + 1)
                writer.add_scalar('acc5', acc5, i_epoch + 1)
                # Milestone snapshots at fixed epochs.
                if i_epoch in [30, 60, 120, 160, 200]:
                    torch.save(
                        checkpoint,
                        os.path.join(args.exp, 'models',
                                     '{}.pth'.format(i_epoch + 1)))
                logging.info(
                    colorful('[Epoch: {}] val acc: {:.4f}/{:.4f}'.format(
                        i_epoch, acc1, acc5)))
                logging.info(
                    colorful('[Epoch: {}] best acc: {:.4f}'.format(
                        i_epoch, best_acc)))

            # Histogram non-batchnorm parameters for TensorBoard.
            with torch.no_grad():
                for name, param in network.named_parameters():
                    if 'bn' not in name:
                        writer.add_histogram(name, param, i_epoch)

        # cluster
    except KeyboardInterrupt as e:
        logging.info('KeyboardInterrupt at {} Epochs'.format(i_epoch))
        exit()
def load(config):
    """Build the model/EMA pair, optimizer and semi-supervised method.

    Args:
        config: configuration object providing at least ``model``,
            ``dataset_classes``, ``semi_supervised``, ``device``,
            ``learning_rate``, ``ema_decay``, ``resume`` and
            ``checkpoint_path``.

    Returns:
        Tuple ``(model, ema_model, optimizer, ema_optimizer,
        semi_supervised, semi_supervised_loss)``.

    Raises:
        ValueError: if ``config.semi_supervised`` names an unknown method.
    """
    if config.model == 'wideresnet':
        model = WideResNet(num_classes=config.dataset_classes)
        ema_model = WideResNet(num_classes=config.dataset_classes)
    else:
        # Any other value falls back to the CNN13 architecture (original
        # behavior preserved).
        model = CNN13(num_classes=config.dataset_classes)
        ema_model = CNN13(num_classes=config.dataset_classes)

    if config.semi_supervised == 'mix_match':
        semi_supervised = MixMatch(config)
        semi_supervised_loss = mix_match_loss
    elif config.semi_supervised == 'pseudo_label':
        semi_supervised = PseudoLabel(config)
        semi_supervised_loss = pseudo_label_loss
    else:
        # Previously an unknown method left semi_supervised(_loss) unbound
        # and crashed later with UnboundLocalError; fail fast instead.
        raise ValueError(
            'unknown semi-supervised method: {!r}'.format(config.semi_supervised))

    model.to(config.device)
    ema_model.to(config.device)
    torch.backends.cudnn.benchmark = True

    optimizer = Adam(model.parameters(), lr=config.learning_rate)
    ema_optimizer = WeightEMA(model, ema_model, alpha=config.ema_decay)

    if config.resume:
        checkpoint = torch.load(config.checkpoint_path, map_location=config.device)
        model.load_state_dict(checkpoint['model_state'])
        ema_model.load_state_dict(checkpoint['ema_model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        # Optimizer state tensors must be moved to the target device, since
        # load_state_dict restores them on the device they were saved from.
        for optimizer_state in optimizer.state.values():
            for k, v in optimizer_state.items():
                if isinstance(v, torch.Tensor):
                    optimizer_state[k] = v.to(config.device)

    return model, ema_model, optimizer, ema_optimizer, semi_supervised, semi_supervised_loss
'--path_to_save', type=str, default='./results/', help='directory to save the results, must already exist') hps = parser.parse_args() hps.n_labels = 43 if hps.dataset == 'gts' else 10 hps.las = True if hps.las in ['True', 'true', '1'] else False if hps.eps == -1: hps.eps = 0.3 if hps.dataset == 'mnist' else 0.0314 assert hps.p == 'linf', 'Lp-norm not supported' ### load models and datasets one can get at https://github.com/yaodongyu/TRADES if hps.dataset == 'cifar10': from models.wideresnet import WideResNet device = torch.device("cuda") model = WideResNet().to(device) model.load_state_dict(torch.load('./checkpoints/model_cifar_wrn.pt')) model.eval() X_data = np.load('./data_attack/cifar10_X.npy') Y_data = np.load('./data_attack/cifar10_Y.npy') X_data = np.transpose(X_data, (0, 3, 1, 2)) elif hps.dataset == 'mnist': from models.small_cnn import SmallCNN device = torch.device("cuda") model = SmallCNN().to(device) model.load_state_dict( torch.load('./checkpoints/model_mnist_smallcnn.pt')) model.eval()