Example #1
def test():
    device = torch.device('cpu')
    arguments = load_test_arguments()
    checkpoints = {}

    if arguments.checkpoint_file:
        checkpoints[arguments.checkpoint_file] = arguments.checkpoint_file
    else:
        for root, dirs, files in os.walk(arguments.checkpoint_dir):
            for file in files:
                if file.endswith('.bin'):
                    checkpoints[file] = os.path.join(root, file)

    test_dataloader = load_test_data(arguments)

    for checkpoint_name in checkpoints:

        checkpoint = torch.load(checkpoints[checkpoint_name],
                                map_location=device)

        log_checkpoint(checkpoint)

        ema_model = WideResNet(num_classes=10)
        ema_model.load_state_dict(checkpoint['ema_model_state'])

        batches = len(test_dataloader)

        metrics = {'test_steps': 0, 'test_accuracy': 0}

        test_progress_bar = tqdm(enumerate(test_dataloader))
        for batch_step, batch in test_progress_bar:
            test_step(ema_model, batch, metrics, device)
            on_test_batch_end(batch_step, batches, metrics, test_progress_bar)
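The loop above relies on test_step and on_test_batch_end, which are defined elsewhere in the repository. A minimal sketch of what test_step might look like, assuming metrics keeps a running mean of per-batch accuracy (hypothetical helper, not the original implementation):

def test_step(model, batch, metrics, device):
    # Hypothetical sketch: evaluate one batch and update the running accuracy.
    model.eval()
    inputs, labels = batch[0].to(device), batch[1].to(device)
    with torch.no_grad():
        outputs = model(inputs)
    predicted = outputs.argmax(dim=1)
    batch_acc = (predicted == labels).float().mean().item()
    metrics['test_steps'] += 1
    metrics['test_accuracy'] += (batch_acc - metrics['test_accuracy']) / metrics['test_steps']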
Example #2
class WeightEMA(object):
    def __init__(self, model, ema_model, run_type=0, alpha=0.999):
        self.model = model
        self.ema_model = ema_model
        self.alpha = alpha
        if run_type == 1:
            self.tmp_model = Classifier(num_classes=10).cuda()
        else:
            self.tmp_model = WideResNet(num_classes=10).cuda()
        self.wd = 0.02 * args.lr

        for param, ema_param in zip(self.model.parameters(),
                                    self.ema_model.parameters()):
            ema_param.data.copy_(param.data)

    def step(self, bn=False):
        if bn:
            # load the model's BN statistics into the EMA model while keeping the EMA weights
            for ema_param, tmp_param in zip(self.ema_model.parameters(),
                                            self.tmp_model.parameters()):
                tmp_param.data.copy_(ema_param.data.detach())

            self.ema_model.load_state_dict(self.model.state_dict())

            for ema_param, tmp_param in zip(self.ema_model.parameters(),
                                            self.tmp_model.parameters()):
                ema_param.data.copy_(tmp_param.data.detach())
        else:
            one_minus_alpha = 1.0 - self.alpha
            for param, ema_param in zip(self.model.parameters(),
                                        self.ema_model.parameters()):
                ema_param.data.mul_(self.alpha)
                ema_param.data.add_(param.data.detach() * one_minus_alpha)
                # customized weight decay
                param.data.mul_(1 - self.wd)
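The non-BN branch of step() implements the standard exponential moving average, ema = alpha * ema + (1 - alpha) * param, plus a custom weight-decay term on the raw weights. A self-contained sketch of the same update on plain tensors (illustrative values only):

import torch

alpha = 0.999
param = torch.randn(10)
ema_param = param.clone()

for _ in range(100):
    param += 0.01 * torch.randn(10)                      # stand-in for an optimizer step
    ema_param.mul_(alpha).add_(param, alpha=1 - alpha)   # EMA update, as in step()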
Example #3
    def __init__(self, batch_size, model_params, n_steps, optimizer, adam, sgd,
                 steps_validation, steps_checkpoint, dataset, save_path):

        self.n_steps = n_steps
        self.start_step = 0
        self.steps_validation = steps_validation
        self.steps_checkpoint = steps_checkpoint
        self.num_labeled = 50000
        self.train_loader, _, self.val_loader, self.test_loader, self.lbl_idx, _, self.val_idx = \
            get_dataloaders_with_index(path='../data',
                                       batch_size=batch_size,
                                       num_labeled=self.num_labeled,
                                       which_dataset=dataset,
                                       validation=False)
        print('Labeled samples: ' + str(len(self.train_loader.sampler)))

        self.batch_size = self.train_loader.batch_size
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        print(self.device)

        depth, k, n_out = model_params
        self.model = WideResNet(depth=depth, k=k, n_out=n_out,
                                bias=True).to(self.device)
        self.ema_model = WideResNet(depth=depth, k=k, n_out=n_out,
                                    bias=True).to(self.device)
        for param in self.ema_model.parameters():
            param.detach_()

        if optimizer == 'adam':
            self.lr, self.weight_decay = adam
            self.momentum, self.lr_decay = None, None
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
            self.ema_optimizer = WeightEMA(self.model,
                                           self.ema_model,
                                           self.lr,
                                           alpha=0.999)

        else:
            self.lr, self.momentum, self.weight_decay, self.lr_decay = sgd
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.lr,
                                       momentum=self.momentum,
                                       weight_decay=self.weight_decay,
                                       nesterov=True)
            self.ema_optimizer = None

        self.criterion = nn.CrossEntropyLoss()

        self.train_accuracies, self.train_losses = [], []
        self.val_accuracies, self.val_losses = [], []
        self.best_acc = 0

        self.augment = Augment(K=1)

        self.path = save_path
        self.writer = SummaryWriter()
Example #4
    def __init__(self, model, ema_model, run_type=0, alpha=0.999):
        self.model = model
        self.ema_model = ema_model
        self.alpha = alpha
        if run_type == 1:
            self.tmp_model = Classifier(num_classes=10).cuda()
        else:
            self.tmp_model = WideResNet(num_classes=10).cuda()
        self.wd = 0.02 * args.lr

        for param, ema_param in zip(self.model.parameters(),
                                    self.ema_model.parameters()):
            ema_param.data.copy_(param.data)
Example #5
def get_model(name, num_classes=10, normalize_input=False, pretrain=False):
    name_parts = name.split('-')
    if name_parts[0] == 'wrn':
        depth = int(name_parts[1])
        widen = int(name_parts[2])
        if not pretrain:
            model = WideResNet(
                depth=depth, num_classes=num_classes, widen_factor=widen)
        else:
            model = WideResNetPre(
                depth=depth, num_classes=num_classes, widen_factor=widen)
        
    elif name_parts[0] == 'ss':
        model = ShakeNet(dict(depth=int(name_parts[1]),
                              base_channels=int(name_parts[2]),
                              shake_forward=True, shake_backward=True,
                              shake_image=True, input_shape=(1, 3, 32, 32),
                              n_classes=num_classes,
                              ))
    elif name_parts[0] == 'resnet':
        depth = int(name_parts[1])
        resnets = {18: ResNet18, 50: ResNet50, 152: ResNet152}
        if depth in resnets:
            # regular resnets
            model = resnets[depth]()
        else:
            # CIFAR - ResNet
            model = ResNet(num_classes=num_classes, depth=depth)
    else:
        raise ValueError('Could not parse model name %s' % name)

    if normalize_input:
        model = Sequential(NormalizeInput(), model)

    return model
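get_model parses dash-separated names, so 'wrn-28-10' yields a 28-layer WideResNet with widen factor 10. A hedged usage sketch, assuming the classes above are importable:

model = get_model('wrn-28-10', num_classes=10)           # WideResNet(depth=28, widen_factor=10)
robust = get_model('resnet-18', normalize_input=True)    # ResNet18 behind a NormalizeInput layer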
Example #6
    def setup(self, flags):
        torch.backends.cudnn.deterministic = flags.deterministic
        print('torch.backends.cudnn.deterministic:', torch.backends.cudnn.deterministic)
        fix_all_seed(flags.seed)

        if flags.dataset == 'cifar10':
            num_classes = 10
        else:
            num_classes = 100

        if flags.model == 'densenet':
            self.network = densenet(num_classes=num_classes)
        elif flags.model == 'wrn':
            self.network = WideResNet(flags.layers, num_classes, flags.widen_factor, flags.droprate)
        elif flags.model == 'allconv':
            self.network = AllConvNet(num_classes)
        elif flags.model == 'resnext':
            self.network = resnext29(num_classes=num_classes)
        else:
            raise Exception('Unknown model.')
        self.network = self.network.cuda()

        print(self.network)
        print('flags:', flags)
        if not os.path.exists(flags.logs):
            os.makedirs(flags.logs)

        flags_log = os.path.join(flags.logs, 'flags_log.txt')
        write_log(flags, flags_log)
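fix_all_seed is defined elsewhere; a typical implementation seeds every RNG the run touches (hypothetical sketch, not the original):

import random
import numpy as np
import torch

def fix_all_seed(seed):
    # Hypothetical sketch: seed Python, NumPy, and torch (CPU and CUDA).
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)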
Example #7
def build_model(args):
    # model = ResNet32(args.dataset == 'cifar10' and 10 or 100)
    model = WideResNet(args.layers,
                       args.dataset == 'cifar10' and 10 or 100,
                       args.widen_factor,
                       dropRate=args.droprate)
    # weights_init(model)

    # print('Number of model parameters: {}'.format(
    #     sum([p.data.nelement() for p in model.params()])))

    if torch.cuda.is_available():
        model.cuda()
        torch.backends.cudnn.benchmark = True

    return model
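The expression args.dataset == 'cifar10' and 10 or 100 is the pre-ternary and/or idiom; because 10 is truthy it is equivalent to the clearer conditional expression:

num_classes = 10 if args.dataset == 'cifar10' else 100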
Example #8
def get_model(model, args):
    if model == 'alexnet':
        return alexnet()
    if model == 'resnet':
        return resnet(dataset=args.dataset)
    if model == 'wideresnet':
        return WideResNet(args.layers, args.dataset == 'cifar10' and 10 or 100,
                          args.widen_factor, dropRate=args.droprate, gbn=args.gbn)
    if model == 'densenet':
        return densenet()
Example #9
    def create_model(ema=False):
        if args.type == 1:
            model = Classifier(num_classes=10)
        else:
            model = WideResNet(num_classes=10)
        model = model.cuda()

        if ema:
            for param in model.parameters():
                param.detach_()

        return model
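Detaching the parameters of the EMA copy removes them from the autograd graph, so they change only through the explicit EMA update and never through backpropagation. For freshly initialized leaf parameters, a roughly equivalent and more explicit form is:

ema_model = create_model(ema=True)
# Roughly equivalent for newly initialized parameters:
for p in ema_model.parameters():
    p.requires_grad_(False)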
Example #10
def get_model(name, num_classes=10, normalize_input=False):
    name_parts = name.split('-')
    if name_parts[0] == 'wrn':
        depth = int(name_parts[1])
        widen = int(name_parts[2])
        model = WideResNet(
            depth=depth, num_classes=num_classes, widen_factor=widen)
        
    elif name_parts[0] == 'ss':
        model = ShakeNet(dict(depth=int(name_parts[1]),
                              base_channels=int(name_parts[2]),
                              shake_forward=True, shake_backward=True,
                              shake_image=True, input_shape=(1, 3, 32, 32),
                              n_classes=num_classes,
                              ))
    elif name_parts[0] == 'resnet':
        model = ResNet(num_classes=num_classes, depth=int(name_parts[1]))
    else:
        raise ValueError('Could not parse model name %s' % name)

    if normalize_input:
        model = Sequential(NormalizeInput(), model)

    return model
Example #11
class FullySupervisedTrainer:
    def __init__(self, batch_size, model_params, n_steps, optimizer, adam, sgd,
                 steps_validation, steps_checkpoint, dataset, save_path):

        self.n_steps = n_steps
        self.start_step = 0
        self.steps_validation = steps_validation
        self.steps_checkpoint = steps_checkpoint
        self.num_labeled = 50000
        self.train_loader, _, self.val_loader, self.test_loader, self.lbl_idx, _, self.val_idx = \
            get_dataloaders_with_index(path='../data',
                                       batch_size=batch_size,
                                       num_labeled=self.num_labeled,
                                       which_dataset=dataset,
                                       validation=False)
        print('Labeled samples: ' + str(len(self.train_loader.sampler)))

        self.batch_size = self.train_loader.batch_size
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        print(self.device)

        depth, k, n_out = model_params
        self.model = WideResNet(depth=depth, k=k, n_out=n_out,
                                bias=True).to(self.device)
        self.ema_model = WideResNet(depth=depth, k=k, n_out=n_out,
                                    bias=True).to(self.device)
        for param in self.ema_model.parameters():
            param.detach_()

        if optimizer == 'adam':
            self.lr, self.weight_decay = adam
            self.momentum, self.lr_decay = None, None
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
            self.ema_optimizer = WeightEMA(self.model,
                                           self.ema_model,
                                           self.lr,
                                           alpha=0.999)

        else:
            self.lr, self.momentum, self.weight_decay, self.lr_decay = sgd
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.lr,
                                       momentum=self.momentum,
                                       weight_decay=self.weight_decay,
                                       nesterov=True)
            self.ema_optimizer = None

        self.criterion = nn.CrossEntropyLoss()

        self.train_accuracies, self.train_losses = [], []
        self.val_accuracies, self.val_losses = [], []
        self.best_acc = 0

        self.augment = Augment(K=1)

        self.path = save_path
        self.writer = SummaryWriter()

    def train(self):
        iter_train_loader = iter(self.train_loader)

        for step in range(self.n_steps):
            self.model.train()
            # Get next batch of data
            try:
                x_input, x_labels, _ = next(iter_train_loader)
                # Restart the iterator if the last batch was cropped
                if x_input.shape[0] < self.batch_size:
                    iter_train_loader = iter(self.train_loader)
                    x_input, x_labels, _ = next(iter_train_loader)
            except StopIteration:
                iter_train_loader = iter(self.train_loader)
                x_input, x_labels, _ = next(iter_train_loader)

            # Send to GPU
            x_input = x_input.to(self.device)
            x_labels = x_labels.to(self.device)

            # Augment
            x_input = self.augment(x_input)
            x_input = x_input.reshape((-1, 3, 32, 32))

            # Compute X' predictions
            x_output = self.model(x_input)

            # Compute loss
            loss = self.criterion(x_output, x_labels)

            # Step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if self.ema_optimizer:
                self.ema_optimizer.step()

            # Decaying learning rate, used with the SGD Nesterov optimizer
            if not self.ema_optimizer and step in self.lr_decay:
                for g in self.optimizer.param_groups:
                    g['lr'] *= 0.2

            # Evaluate model
            self.model.eval()
            if step > 0 and not step % self.steps_validation:
                val_acc, is_best = self.evaluate_loss_acc(step)
                if is_best:
                    self.save_model(step=step,
                                    path=f'{self.path}/best_checkpoint.pt')

            # Save checkpoint
            if step > 10000 and not step % self.steps_checkpoint:
                self.save_model(step=step,
                                path=f'{self.path}/checkpoint_{step}.pt')

        # --- Training finished ---
        test_val, test_acc = self.evaluate(self.test_loader)
        print("Training done!!\t Test loss: %.3f \t Test accuracy: %.3f" %
              (test_val, test_acc))

        self.writer.flush()

    # --- support functions ---

    def evaluate_loss_acc(self, step):
        val_loss, val_acc = self.evaluate(self.val_loader)
        self.val_losses.append(val_loss)
        self.val_accuracies.append(val_acc)
        train_loss, train_acc = self.evaluate(self.train_loader)
        self.train_losses.append(train_loss)
        self.train_accuracies.append(train_acc)

        is_best = False
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            is_best = True

        print(
            "Step %d.\tLoss train_lbl/valid  %.2f  %.2f\t Accuracy train_lbl/valid  %.2f  %.2f \tBest acc %.2f \t%s"
            % (step, train_loss, val_loss, train_acc, val_acc, self.best_acc,
               time.ctime()))

        self.writer.add_scalar("Loss train_label", train_loss, step)
        self.writer.add_scalar("Loss validation", val_loss, step)
        self.writer.add_scalar("Accuracy train_label", train_acc, step)
        self.writer.add_scalar("Accuracy validation", val_acc, step)
        return val_acc, is_best

    def evaluate(self, dataloader):
        correct, total, loss = 0, 0, 0
        with torch.no_grad():
            for i, data in enumerate(dataloader, 0):
                inputs, labels = data[0].to(self.device), data[1].to(
                    self.device)
                if self.ema_optimizer:
                    outputs = self.ema_model(inputs)
                else:
                    outputs = self.model(inputs)
                loss += self.criterion(outputs, labels).item()

                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            loss /= len(dataloader)

        acc = correct / total * 100
        return loss, acc

    def save_model(self, step=None, path='../models/model.pt'):
        if not step:
            step = self.n_steps  # Training finished

        torch.save(
            {
                'step': step,
                'model_state_dict': self.model.state_dict(),
                'ema_state_dict': self.ema_model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss_train': self.train_losses,
                'loss_val': self.val_losses,
                'acc_train': self.train_accuracies,
                'acc_val': self.val_accuracies,
                'steps': self.n_steps,
                'batch_size': self.batch_size,
                'num_labels': self.num_labeled,
                'lr': self.lr,
                'weight_decay': self.weight_decay,
                'momentum': self.momentum,
                'lr_decay': self.lr_decay,
                'lbl_idx': self.lbl_idx,
                'val_idx': self.val_idx,
            }, path)

    def load_checkpoint(self, model_name):
        saved_model = torch.load(f'../models/{model_name}')
        self.model.load_state_dict(saved_model['model_state_dict'])
        self.ema_model.load_state_dict(saved_model['ema_state_dict'])
        self.optimizer.load_state_dict(saved_model['optimizer_state_dict'])
        self.start_step = saved_model['step']

        self.train_loader, _, self.val_loader, self.test_loader, self.lbl_idx, _, self.val_idx = \
            get_dataloaders_with_index(path='../data',
                                       batch_size=self.batch_size,
                                       num_labeled=self.num_labeled,
                                       which_dataset='cifar10',
                                       lbl_idxs=saved_model['lbl_idx'],
                                       unlbl_idxs=[],
                                       valid_idxs=saved_model['val_idx'])
        print('Model ' + model_name + ' loaded.')
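A hypothetical driver for the trainer above; every hyperparameter value here is illustrative, not taken from the original configuration:

trainer = FullySupervisedTrainer(batch_size=64,
                                 model_params=(28, 2, 10),                # depth, k, n_out
                                 n_steps=100000,
                                 optimizer='sgd',
                                 adam=(0.002, 0.00004),                   # lr, weight_decay
                                 sgd=(0.1, 0.9, 0.0005, [60000, 80000]),  # lr, momentum, wd, decay steps
                                 steps_validation=1000,
                                 steps_checkpoint=20000,
                                 dataset='cifar10',
                                 save_path='../models')
trainer.train()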
Example #12

if __name__ == "__main__":
    from models.wideresnet import WideResNet
    # from models.WideResNet import WideResNet
    import torch

    # tensorboard writer
    tb_writer = SummaryWriter(log_dir=args.tbsw_logdir)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    K = args.K  # steps: 2**20 ideal
    B = args.B  # batch size: 64 ideal
    mu = args.mu
    strides = [1, 1, 2, 2]
    # model = WideResNet(d=28, k=3, n_classes=10, input_features=3, output_features=16, strides=strides)                # prop unsup/sup: 7 ideal
    model = WideResNet(3, 28, 2, 10)

    # Creating Dataset
    labels_per_class = [args.labels_per_class for _ in range(10)]
    dataset_loader = CustomLoader(labels_per_class=labels_per_class,
                                  db_dir=args.root,
                                  db_name=args.use_database,
                                  mode=args.task,
                                  download=args.download)
    dataset_loader.load()

    unlabeled_dataset = DataSet(dataset_loader.get_set(
        dataset_loader.unlabeled_set),
                                batch=B * mu,
                                steps=K)
    labeled_dataset = DataSet(dataset_loader.get_set(
Example #13
    elif args.ood_dataset == 'all':
        for i in range(len(ood_dataset.datasets)):
            ood_loader.dataset.datasets[i].imgs = ood_loader.dataset.datasets[
                i].imgs[1000:]
        ood_loader.dataset.cummulative_sizes = ood_loader.dataset.cumsum(
            ood_loader.dataset.datasets)
    else:
        ood_loader.dataset.imgs = ood_loader.dataset.imgs[1000:]
        ood_loader.dataset.__len__ = len(ood_loader.dataset.imgs)

#------------------------------------------------------------------------------
# Load pre-trained model
#------------------------------------------------------------------------------

if args.model == 'wideresnet':
    cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10).cuda()
elif args.model == 'wideresnet16_8':
    cnn = WideResNet(depth=16, num_classes=num_classes, widen_factor=8).cuda()
elif args.model == 'densenet':
    cnn = DenseNet3(depth=100,
                    num_classes=num_classes,
                    growth_rate=12,
                    reduction=0.5).cuda()
elif args.model == 'vgg13':
    cnn = VGG(vgg_name='VGG13', num_classes=num_classes).cuda()

model_dict = cnn.state_dict()

pretrained_dict = torch.load('checkpoints/' + filename + '.pt')
cnn.load_state_dict(pretrained_dict)
cnn = cnn.cuda()
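The load above assumes the checkpoint and the model live on the same device. A hedged variant that also works on CPU-only machines uses map_location:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pretrained_dict = torch.load('checkpoints/' + filename + '.pt', map_location=device)
cnn.load_state_dict(pretrained_dict)
cnn = cnn.to(device)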
Example #14
    
    tensors = [torch.FloatTensor(np.array(x[0]).transpose(2, 0, 1)) / 255 for x in ims]
    labels = [x[1] for x in ims]

    s = torch.stack(tensors)

    return s, torch.LongTensor(labels)

if __name__ == "__main__":
    from models.wideresnet import WideResNet
    # from models.WideResNet import WideResNet
    import torch
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # model = WideResNet(d=28, k=3, n_classes=10, input_features=3, output_features=16, strides=strides)                # prop unsup/sup: 7 ideal
    model = WideResNet(3, 28, 2, 10)

    # Creating Dataset
    labels_per_class = [100 for _ in range(10)]
    dataset_loader = CustomLoader(labels_per_class=labels_per_class,
                                  db_dir=args.root,
                                  db_name=args.use_database,
                                  mode=args.task,
                                  download=args.download)
    dataset_loader.load()

    test_data = DataSet(dataset_loader.database, batch=1)

    test_loader = DataLoader(test_data, collate_fn=default_collate_fn, batch_size=64)

    model_directory = os.path.expanduser(args.save_dir)

    model = WideResNet(3, 28, 2, 10)
    state_dict = torch.load(os.path.join(model_directory, args.name_model_specs))
Example #15
        plot_histograms(correct, confidence)

    val_acc = np.mean(correct)
    conf_min = np.min(confidence)
    conf_max = np.max(confidence)
    conf_avg = np.mean(confidence)

    cnn.train()
    return val_acc, conf_min, conf_max, conf_avg


#------------------------------------------------------------------------------
# Configure model, dataset, optimizer, scheduler, csv_logger, and lambda
#------------------------------------------------------------------------------
if args.model == 'wideresnet':
    cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10).cuda()
elif args.model == 'wideresnet16_8':
    cnn = WideResNet(depth=16, num_classes=num_classes, widen_factor=8).cuda()
elif args.model == 'densenet':
    cnn = DenseNet3(depth=100,
                    num_classes=num_classes,
                    growth_rate=12,
                    reduction=0.5).cuda()
elif args.model == 'vgg13':
    cnn = VGG(vgg_name='VGG13', num_classes=num_classes).cuda()

prediction_criterion = nn.NLLLoss().cuda()

cnn_optimizer = torch.optim.SGD(cnn.parameters(),
                                lr=args.learning_rate,
                                momentum=0.9,
Example #16
def get_model(config, num_class=10, bn_types=None, data_parallel=True):
    name = config.model
    print('model name: {}'.format(name))
    print('bn_types: {}'.format(bn_types))
    if name == 'resnet50':
        if bn_types is None:
            model = ResNet(dataset='imagenet',
                           depth=50,
                           num_classes=num_class,
                           bottleneck=True)
        else:
            model = ResNetMultiBN(dataset='imagenet',
                                  depth=50,
                                  num_classes=num_class,
                                  bn_types=bn_types,
                                  bottleneck=True)
    elif name == 'resnet200':
        if bn_types is None:
            model = ResNet(dataset='imagenet',
                           depth=200,
                           num_classes=num_class,
                           bottleneck=True)
        else:
            model = ResNetMultiBN(dataset='imagenet',
                                  depth=200,
                                  num_classes=num_class,
                                  bn_types=bn_types,
                                  bottleneck=True)
    elif name == 'wresnet40_2':
        if bn_types is None:
            model = WideResNet(40, 2, dropout_rate=0.0, num_classes=num_class)
        else:
            raise Exception('unimplemented error')
    elif name == 'wresnet28_10':
        if bn_types is None:
            model = WideResNet(28, 10, dropout_rate=0.0, num_classes=num_class)
        else:
            model = WideResNetMultiBN(28,
                                      10,
                                      dropout_rate=0.0,
                                      num_classes=num_class,
                                      bn_types=bn_types)
    elif name == 'shakeshake26_2x32d':
        if bn_types is None:
            model = ShakeResNet(26, 32, num_class)
        else:
            model = ShakeResNetMultiBN(26, 32, num_class, bn_types)
    elif name == 'shakeshake26_2x64d':
        if bn_types is None:
            model = ShakeResNet(26, 64, num_class)
        else:
            model = ShakeResNetMultiBN(26, 64, num_class, bn_types)
    elif name == 'shakeshake26_2x96d':
        if bn_types is None:
            model = ShakeResNet(26, 96, num_class)
        else:
            model = ShakeResNetMultiBN(26, 96, num_class, bn_types)
    elif name == 'shakeshake26_2x112d':
        if bn_types is None:
            model = ShakeResNet(26, 112, num_class)
        else:
            model = ShakeResNetMultiBN(26, 112, num_class, bn_types)
    elif name == 'shakeshake26_2x96d_next':
        if bn_types is None:
            model = ShakeResNeXt(26, 96, 4, num_class)
        else:
            raise Exception('unimplemented error')

    elif name == 'pyramid':
        if bn_types is None:
            model = PyramidNet('cifar10',
                               depth=config.pyramidnet_depth,
                               alpha=config.pyramidnet_alpha,
                               num_classes=num_class,
                               bottleneck=True)
        else:
            model = PyramidNetMultiBN('cifar10',
                                      depth=config.pyramidnet_depth,
                                      alpha=config.pyramidnet_alpha,
                                      num_classes=num_class,
                                      bottleneck=True,
                                      bn_types=bn_types)
    else:
        raise NameError('no model named %s' % name)

    if data_parallel:
        model = model.cuda()
        model = DataParallel(model)
    else:
        import horovod.torch as hvd
        device = torch.device('cuda', hvd.local_rank())
        model = model.to(device)
    cudnn.benchmark = True
    return model
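A hedged usage sketch for the dispatcher above, assuming a minimal config object with a model attribute (requires a CUDA device when data_parallel=True):

from types import SimpleNamespace

config = SimpleNamespace(model='wresnet28_10')   # illustrative config
model = get_model(config, num_class=10, bn_types=None, data_parallel=True)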
Example #17
    def __init__(self, batch_size, num_lbls, model_params, n_steps, K, lambda_u, optimizer, adam,
                 sgd, steps_validation, steps_checkpoint, dataset, save_path, use_pseudo, tau):

        self.validation_set = False

        self.n_steps = n_steps
        self.start_step = 0
        self.K = K
        self.steps_validation = steps_validation
        self.steps_checkpoint = steps_checkpoint
        self.num_labeled = num_lbls
        self.labeled_loader, self.unlabeled_loader, self.val_loader, self.test_loader, self.lbl_idx, self.unlbl_idx, self.val_idx \
            = get_dataloaders_with_index(path='../data',
                                         batch_size=batch_size,
                                         num_labeled=num_lbls,
                                         which_dataset=dataset,
                                         validation=self.validation_set)
        print('Labeled samples: ' + str(len(self.labeled_loader.sampler)) + '\tUnlabeled samples: ' + str(len(self.unlabeled_loader.sampler)))
        self.targets_list = np.array(self.labeled_loader.dataset.targets)
        self.batch_size = self.labeled_loader.batch_size
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(self.device)

        # -- Model --
        depth, k, n_out = model_params
        self.model = WideResNet(depth=depth, k=k, n_out=n_out, bias=False).to(self.device)
        self.ema_model = WideResNet(depth=depth, k=k, n_out=n_out, bias=False).to(self.device)
        for param in self.ema_model.parameters():
            param.detach_()

        if optimizer == 'adam':
            self.lr, self.weight_decay = adam
            self.momentum, self.lr_decay = None, None
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
            self.ema_optimizer = WeightEMA(self.model, self.ema_model, self.lr, alpha=0.999)

        else:
            self.lr, self.momentum, self.weight_decay, self.lr_decay = sgd
            self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay, nesterov=True)
            self.ema_optimizer = None

        self.lambda_u_max, self.step_top_up = lambda_u
        self.loss_mixmatch = Loss(self.lambda_u_max, self.step_top_up)
        self.criterion = nn.CrossEntropyLoss()

        self.train_accuracies, self.train_losses = [], []
        self.val_accuracies, self.val_losses = [], []
        self.best_acc = 0

        self.mixmatch = MixMatch(self.model, self.batch_size, self.device)

        self.writer = SummaryWriter()
        self.path = save_path

        # -- Pseudo label --
        self.use_pseudo = use_pseudo
        self.steps_pseudo_lbl = 5000
        self.tau = tau  # confidence threshold

        self.min_unlbl_samples = 1000
        # Make a deep copy of original unlabeled loader
        _, self.unlabeled_loader_original, _, _, _, _, _ \
            = get_dataloaders_with_index(path='../data',
                                         batch_size=batch_size,
                                         num_labeled=num_lbls,
                                         which_dataset=dataset,
                                         lbl_idxs=self.lbl_idx,
                                         unlbl_idxs=self.unlbl_idx,
                                         valid_idxs=self.val_idx,
                                         validation=self.validation_set)
Example #18
def main():
    global args, best_prec1
    args = parser.parse_args()
    # torch.cuda.set_device(args.gpu)
    if args.tensorboard:
        print("Using tensorboard")
        configure("exp/%s" % (args.name))

    # Data loading code

    if args.augment:
        print(
            "Doing image augmentation with\n"
            "Zoom: prob: {zoom_prob}  range: {zoom_range}\n"
            "Stretch: prob: {stretch_prob} range: {stretch_range}\n"
            "Rotation: prob: {rotation_prob} range: {rotation_degree}".format(
                zoom_prob=args.zoom_prob,
                zoom_range=args.zoom_range,
                stretch_prob=args.stretch_prob,
                stretch_range=args.stretch_range,
                rotation_prob=args.rotation_prob,
                rotation_degree=args.rotation_degree))
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(
                Variable(x.unsqueeze(0), requires_grad=False, volatile=True),
                (4, 4, 4, 4),
                mode='replicate').data.squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(prob=args.rotation_prob,
                                      degree=args.rotation_degree),
            transforms.RandomZoom(prob=args.zoom_prob,
                                  zoom_range=args.zoom_range),
            transforms.RandomStretch(prob=args.stretch_prob,
                                     stretch_range=args.stretch_range),
            transforms.ToTensor(),
        ])

    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    kwargs = {'num_workers': 1, 'pin_memory': True}
    assert (args.dataset == 'cifar10' or args.dataset == 'cifar100')
    train_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data',
                                                train=True,
                                                download=True,
                                                transform=transform_train),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    val_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data',
                                                train=False,
                                                transform=transform_test),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)

    # create model
    model = WideResNet(args.layers,
                       args.dataset == 'cifar10' and 10 or 100,
                       args.widen_factor,
                       dropRate=args.droprate)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch + 1)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)
    print('Best accuracy: ', best_prec1)
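adjust_learning_rate and save_checkpoint are defined elsewhere in the script. A hypothetical sketch of a step schedule in the spirit of common WideResNet recipes (decay by 10x at 50% and 75% of total epochs); the original schedule may differ:

def adjust_learning_rate(optimizer, epoch):
    # Hypothetical sketch, not the original schedule; assumes the global args.
    lr = args.lr
    if epoch >= args.epochs * 0.5:
        lr *= 0.1
    if epoch >= args.epochs * 0.75:
        lr *= 0.1
    for group in optimizer.param_groups:
        group['lr'] = lr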
Example #19
 'AllConv_PGP': lambda args: FuseTrainWrapper(
     AllConv_PGP(args.nclasses)),
 'PreResNet164': lambda args: L.Classifier(
     PreResNet(164, args.nclasses)),
 'PreResNet164_DConv': lambda args: L.Classifier(
     PreResNet_DConv(164, args.nclasses)),
 'PreResNet164_PGP': lambda args: FuseTrainWrapper(
     PreResNet_PGP(164, args.nclasses)),
 'DenseNetBC100': lambda args: L.Classifier(
     DenseNetBC(args.nclasses, (16, 16, 16), 12)),
 'DenseNetBC100_DConv': lambda args: L.Classifier(
     DenseNetBC_DConv(args.nclasses, (16, 16, 16), 12)),
 'DenseNetBC100_PGP': lambda args: FuseTrainWrapper(
     DenseNetBC_PGP(args.nclasses, (16, 16, 16), 12)),
 'WideResNet28-10': lambda args: L.Classifier(
     WideResNet(28, args.nclasses, 10)),
 'WideResNet28-10_DConv': lambda args: L.Classifier(
     WideResNet_DConv(28, args.nclasses, 10)),
 'WideResNet28-10_PGP': lambda args: FuseTrainWrapper(
     WideResNet_PGP(28, args.nclasses, 10)),
 'ResNeXt29_8x64d': lambda args: L.Classifier(
     ResNeXt(29, args.nclasses)),
 'ResNeXt29_8x64d_DConv': lambda args: L.Classifier(
     ResNeXt_DConv(29, args.nclasses)),
 'ResNeXt29_8x64d_PGP': lambda args: FuseTrainWrapper(
     ResNeXt_PGP(29, args.nclasses)),
 'PyramidNetB164': lambda args: L.Classifier(
     PyramidNet(164, args.nclasses)),
 'PyramidNetB164_DConv': lambda args: L.Classifier(
     PyramidNet_DConv(164, args.nclasses)),
 'PyramidNetB164_PGP': lambda args: FuseTrainWrapper(
Example #20
                  growth_rate=12,
                  n_class=args.nclasses,
                  in_ch=24,
                  block=3,
                  bottleneck=True,
                  reduction=0.5)),
 'DenseNetBC100':
 lambda args: L.Classifier(DenseNetBC(args.nclasses, (16, 16, 16), 12)),
 'DenseNetBC100_DConv':
 lambda args: L.Classifier(DenseNetBC_DConv(args.nclasses,
                                            (16, 16, 16), 12)),
 'DenseNetBC100_PGP':
 lambda args: FuseTrainWrapper(
     DenseNetBC_PGP(args.nclasses, (16, 16, 16), 12)),
 'WideResNet28-10':
 lambda args: L.Classifier(WideResNet(28, args.nclasses, 10)),
 'WideResNet28-10_DConv':
 lambda args: L.Classifier(WideResNet_DConv(28, args.nclasses, 10)),
 'WideResNet28-10_PGP':
 lambda args: FuseTrainWrapper(WideResNet_PGP(28, args.nclasses, 10)),
 'ResNeXt29_8x64d':
 lambda args: L.Classifier(ResNeXt(29, args.nclasses)),
 'ResNeXt29_8x64d_DConv':
 lambda args: L.Classifier(ResNeXt_DConv(29, args.nclasses)),
 'ResNeXt29_8x64d_PGP':
 lambda args: FuseTrainWrapper(ResNeXt_PGP(29, args.nclasses)),
 'PyramidNetB164':
 lambda args: L.Classifier(PyramidNet(164, args.nclasses)),
 'PyramidNetB164_DConv':
 lambda args: L.Classifier(PyramidNet_DConv(164, args.nclasses)),
 'PyramidNetB164_PGP':
Example #21
class MixMatchTrainer:

    def __init__(self, batch_size, num_lbls, model_params, n_steps, K, lambda_u, optimizer, adam,
                 sgd, steps_validation, steps_checkpoint, dataset, save_path, use_pseudo, tau):

        self.validation_set = False

        self.n_steps = n_steps
        self.start_step = 0
        self.K = K
        self.steps_validation = steps_validation
        self.steps_checkpoint = steps_checkpoint
        self.num_labeled = num_lbls
        self.labeled_loader, self.unlabeled_loader, self.val_loader, self.test_loader, self.lbl_idx, self.unlbl_idx, self.val_idx \
            = get_dataloaders_with_index(path='../data',
                                         batch_size=batch_size,
                                         num_labeled=num_lbls,
                                         which_dataset=dataset,
                                         validation=self.validation_set)
        print('Labeled samples: ' + str(len(self.labeled_loader.sampler)) + '\tUnlabeled samples: ' + str(len(self.unlabeled_loader.sampler)))
        self.targets_list = np.array(self.labeled_loader.dataset.targets)
        self.batch_size = self.labeled_loader.batch_size
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(self.device)

        # -- Model --
        depth, k, n_out = model_params
        self.model = WideResNet(depth=depth, k=k, n_out=n_out, bias=False).to(self.device)
        self.ema_model = WideResNet(depth=depth, k=k, n_out=n_out, bias=False).to(self.device)
        for param in self.ema_model.parameters():
            param.detach_()

        if optimizer == 'adam':
            self.lr, self.weight_decay = adam
            self.momentum, self.lr_decay = None, None
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
            self.ema_optimizer = WeightEMA(self.model, self.ema_model, self.lr, alpha=0.999)

        else:
            self.lr, self.momentum, self.weight_decay, self.lr_decay = sgd
            self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay, nesterov=True)
            self.ema_optimizer = None

        self.lambda_u_max, self.step_top_up = lambda_u
        self.loss_mixmatch = Loss(self.lambda_u_max, self.step_top_up)
        self.criterion = nn.CrossEntropyLoss()

        self.train_accuracies, self.train_losses = [], []
        self.val_accuracies, self.val_losses = [], []
        self.best_acc = 0

        self.mixmatch = MixMatch(self.model, self.batch_size, self.device)

        self.writer = SummaryWriter()
        self.path = save_path

        # -- Pseudo label --
        self.use_pseudo = use_pseudo
        self.steps_pseudo_lbl = 5000
        self.tau = tau  # confidence threshold

        self.min_unlbl_samples = 1000
        # Make a deep copy of original unlabeled loader
        _, self.unlabeled_loader_original, _, _, _, _, _ \
            = get_dataloaders_with_index(path='../data',
                                         batch_size=batch_size,
                                         num_labeled=num_lbls,
                                         which_dataset=dataset,
                                         lbl_idxs=self.lbl_idx,
                                         unlbl_idxs=self.unlbl_idx,
                                         valid_idxs=self.val_idx,
                                         validation=self.validation_set)

    def train(self):

        iter_labeled_loader = iter(self.labeled_loader)
        iter_unlabeled_loader = iter(self.unlabeled_loader)

        for step in range(self.start_step, self.n_steps):

            # Get next batch of data
            self.model.train()
            try:
                x_imgs, x_labels, _ = next(iter_labeled_loader)
                # Restart the iterator if the last batch was cropped
                if x_imgs.shape[0] < self.batch_size:
                    iter_labeled_loader = iter(self.labeled_loader)
                    x_imgs, x_labels, _ = next(iter_labeled_loader)
            except StopIteration:
                iter_labeled_loader = iter(self.labeled_loader)
                x_imgs, x_labels, _ = next(iter_labeled_loader)

            try:
                u_imgs, _, _ = next(iter_unlabeled_loader)
                if u_imgs.shape[0] < self.batch_size:
                    iter_unlabeled_loader = iter(self.unlabeled_loader)
                    u_imgs, _, _ = next(iter_unlabeled_loader)
            except StopIteration:
                iter_unlabeled_loader = iter(self.unlabeled_loader)
                u_imgs, _, _ = next(iter_unlabeled_loader)

            # Send to GPU
            x_imgs = x_imgs.to(self.device)
            x_labels = x_labels.to(self.device)
            u_imgs = u_imgs.to(self.device)

            # MixMatch algorithm
            x, u = self.mixmatch.run(x_imgs, x_labels, u_imgs)
            x_input, x_targets = x
            u_input, u_targets = u
            u_targets.detach_()  # stop gradients from propagating back into label guessing

            # Compute X' predictions
            x_output = self.model(x_input)

            # Compute U' predictions. Separate in batches
            u_batch_outs = []
            for k in range(self.K):
                u_batch = u_input[k * self.batch_size:(k + 1) * self.batch_size]
                u_batch_outs.append(self.model(u_batch))
            u_outputs = torch.cat(u_batch_outs, dim=0)

            # Compute loss
            loss = self.loss_mixmatch(x_output, x_targets, u_outputs, u_targets, step)

            # Step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if self.ema_optimizer:
                self.ema_optimizer.step()

            # Decaying learning rate, used with the SGD Nesterov optimizer
            if not self.ema_optimizer and step in self.lr_decay:
                for g in self.optimizer.param_groups:
                    g['lr'] *= 0.2

            # Evaluate model
            self.model.eval()
            if step > 0 and not step % self.steps_validation:
                val_acc, is_best = self.evaluate_loss_acc(step)
                if is_best and step > 10000:
                    self.save_model(step=step, path=f'{self.path}/best_checkpoint.pt')

            # Save checkpoint
            if step > 0 and not step % self.steps_checkpoint:
                self.save_model(step=step, path=f'{self.path}/checkpoint_{step}.pt')

            # Generate Pseudo-labels
            if self.use_pseudo and step >= 50000 and not step % self.steps_pseudo_lbl:
                # matrix columns: [index, confidence, pseudo_label, true_label, is_ground_truth]
                matrix = self.get_pseudo_labels()
                self.print_threshold_comparison(matrix)

                # Generate pseudo set based on threshold (same for all classes)
                if self.tau != -1:
                    matrix = self.generate_pseudo_set(matrix)

                # Generate pseudo set balanced (top 90% guesses of each class)
                else:
                    matrix = self.generate_pseudo_set_balanced(matrix)

                iter_labeled_loader = iter(self.labeled_loader)
                iter_unlabeled_loader = iter(self.unlabeled_loader)

                # Save
                torch.save(matrix, f'{self.path}/pseudo_matrix_balanced_{step}.pt')

        # --- Training finished ---
        test_val, test_acc = self.evaluate(self.test_loader)
        print("Training done!!\t Test loss: %.3f \t Test accuracy: %.3f" % (test_val, test_acc))

        self.writer.flush()

    # --- support functions ---
    def evaluate_loss_acc(self, step):
        val_loss, val_acc = self.evaluate(self.val_loader)
        self.val_losses.append(val_loss)
        self.val_accuracies.append(val_acc)
        train_loss, train_acc = self.evaluate(self.labeled_loader)
        self.train_losses.append(train_loss)
        self.train_accuracies.append(train_acc)

        is_best = False
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            is_best = True

        print("Step %d.\tLoss train_lbl/valid  %.2f  %.2f\t Accuracy train_lbl/valid  %.2f  %.2f \tBest acc %.2f \t%s" %
              (step, train_loss, val_loss, train_acc, val_acc, self.best_acc, time.ctime()))

        self.writer.add_scalar("Loss train_label", train_loss, step)
        self.writer.add_scalar("Loss validation", val_loss, step)
        self.writer.add_scalar("Accuracy train_label", train_acc, step)
        self.writer.add_scalar("Accuracy validation", val_acc, step)
        return val_acc, is_best

    def evaluate(self, dataloader):
        correct, total, loss = 0, 0, 0
        with torch.no_grad():
            for i, data in enumerate(dataloader, 0):
                inputs, labels = data[0].to(self.device), data[1].to(self.device)
                if self.ema_optimizer:
                    outputs = self.ema_model(inputs)
                else:
                    outputs = self.model(inputs)
                loss += self.criterion(outputs, labels).item()

                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            loss /= len(dataloader)

        acc = correct / total * 100
        return loss, acc


    def get_losses(self):
        return self.loss_mixmatch.loss_list, self.loss_mixmatch.lx_list, self.loss_mixmatch.lu_list, self.loss_mixmatch.lu_weighted_list

    def save_model(self, step=None, path='../model.pt'):
        loss_list, lx, lu, lu_weighted = self.get_losses()
        if not step:
            step = self.n_steps     # Training finished

        torch.save({
            'step': step,
            'model_state_dict': self.model.state_dict(),
            'ema_state_dict': self.ema_model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss_train': self.train_losses,
            'loss_val': self.val_losses,
            'acc_train': self.train_accuracies,
            'acc_val': self.val_accuracies,
            'loss_batch': loss_list,
            'lx': lx,
            'lu': lu,
            'lu_weighted': lu_weighted,
            'steps': self.n_steps,
            'batch_size': self.batch_size,
            'num_labels': self.num_labeled,
            'lambda_u_max': self.lambda_u_max,
            'step_top_up': self.step_top_up,
            'lr': self.lr,
            'weight_decay': self.weight_decay,
            'momentum': self.momentum,
            'lr_decay': self.lr_decay,
            'lbl_idx': self.lbl_idx,
            'unlbl_idx': self.unlbl_idx,
            'val_idx': self.val_idx,
        }, path)

    def load_checkpoint(self, path_checkpoint):
        saved_model = torch.load(path_checkpoint)
        self.model.load_state_dict(saved_model['model_state_dict'])
        self.ema_model.load_state_dict(saved_model['ema_state_dict'])
        self.optimizer.load_state_dict(saved_model['optimizer_state_dict'])
        self.start_step = saved_model['step']
        self.train_losses = saved_model['loss_train']
        self.val_losses = saved_model['loss_val']
        self.train_accuracies = saved_model['acc_train']
        self.val_accuracies = saved_model['acc_val']
        self.batch_size = saved_model['batch_size']
        self.num_labeled = saved_model['num_labels']
        self.labeled_loader, self.unlabeled_loader, self.val_loader, self.test_loader, self.lbl_idx, self.unlbl_idx, self.val_idx = \
            get_dataloaders_with_index(path='../data',
                                       batch_size=self.batch_size,
                                       num_labeled=self.num_labeled,
                                       which_dataset='cifar10',
                                       lbl_idxs=saved_model['lbl_idx'],
                                       unlbl_idxs=saved_model['unlbl_idx'],
                                       valid_idxs=saved_model['val_idx'],
                                       validation=self.validation_set)
        self.unlabeled_loader_original = self.unlabeled_loader
        print('Model ' + path_checkpoint + ' loaded.')

    def get_pseudo_labels(self):
        matrix = torch.tensor([], device=self.device)

        # Iterate through unlabeled loader
        for batch_idx, (data, target, idx) in enumerate(self.unlabeled_loader_original):
            with torch.no_grad():
                # Get predictions for unlabeled samples
                out = self.model(data.to(self.device))
                p_out = torch.softmax(out, dim=1)  # turn into probability distribution
                confidence, pseudo_lbl = torch.max(p_out, dim=1)
                pseudo_lbl_batch = torch.vstack((idx.to(self.device), confidence, pseudo_lbl)).T
                # Append to matrix
                matrix = torch.cat((matrix, pseudo_lbl_batch), dim=0)  # (n_unlabeled, 3)

        n_unlabeled = matrix.shape[0]
        indices = matrix[:, 0].cpu().numpy().astype(int)
        ground_truth = self.targets_list[indices]
        matrix = torch.vstack((matrix.T, torch.tensor(ground_truth, device=self.device))).T   # (n_unlabeled, 4)
        matrix = torch.vstack((matrix.T, torch.zeros(n_unlabeled, device=self.device))).T   # (n_unlabeled, 5)

        # matrix columns: [index, confidence, pseudo_label, true_label, is_ground_truth]

        # Check if pseudo label is ground truth
        for i in range(n_unlabeled):
            if matrix[i, 2] == matrix[i, 3]:
                matrix[i, 4] = 1
        return matrix

    def generate_pseudo_set(self, matrix):

        unlbl_mask1 = (matrix[:, 1] < self.tau)
        # unlbl_mask2 = (matrix[:, 1] >= 0.99)
        pseudo_mask = (matrix[:, 1] >= self.tau)
        # pseudo_mask = (matrix[:, 1] >= self.tau) & (matrix[:, 1] < 0.99)

        # unlbl_indices = torch.cat((matrix[unlbl_mask1, 0], matrix[unlbl_mask2, 0]))
        unlbl_indices = matrix[unlbl_mask1, 0]
        matrix = matrix[pseudo_mask, :]
        indices = matrix[:, 0]

        new_lbl_idx = np.int_(torch.cat((torch.tensor(self.lbl_idx, device=self.device), indices)).tolist())
        new_unlbl_idx = np.int_(unlbl_indices.tolist())
        self.labeled_loader, self.unlabeled_loader, self.val_loader, self.test_loader, _, _, new_val_idx = \
            get_dataloaders_with_index(path='../data', batch_size=self.batch_size, num_labeled=self.num_labeled,
                                       which_dataset='cifar10', lbl_idxs=new_lbl_idx, unlbl_idxs=new_unlbl_idx,
                                       valid_idxs=self.val_idx, validation=self.validation_set)

        assert np.allclose(self.val_idx, new_val_idx), 'error'
        assert (len(self.labeled_loader.sampler) + len(self.unlabeled_loader.sampler) == 50000), 'error'

        # Change real labels for pseudo labels
        for i in range(matrix.shape[0]):
            index = int(matrix[i, 0].item())
            assert int(matrix[i, 3]) == self.labeled_loader.dataset.targets[index]
            pseudo_labels = int(matrix[i, 2].item())
            self.labeled_loader.dataset.targets[index] = pseudo_labels

        correct = torch.sum(matrix[:, 4]).item()
        pseudo_acc = correct / matrix.shape[0] * 100 if matrix.shape[0] > 0 else 0
        print('Generated labels: %d\t Correct: %d\t Accuracy: %.2f' % (matrix.shape[0], correct, pseudo_acc))
        print('Training with Labeled / Unlabeled / Validation samples\t %d %d %d' % (len(new_lbl_idx),
                                                                                     len(new_unlbl_idx),
                                                                                     len(self.val_idx)))

        return matrix

    def generate_pseudo_set_balanced(self, matrix_all):

        unlbl_indices = torch.tensor([], device=self.device)

        # Get top 10% confident guesses for each class
        matrix = torch.tensor([], device=self.device)
        for i in range(10):
            matrix_label = matrix_all[matrix_all[:, 2] == i, :]
            threshold = torch.quantile(matrix_label[:, 1], 0.9)     # probability at the 90th percentile
            unlbl_idxs = matrix_label[matrix_label[:, 1] < threshold, 0]
            matrix_label = matrix_label[matrix_label[:, 1] >= threshold, :]
            matrix = torch.cat((matrix, matrix_label), dim=0)
            unlbl_indices = torch.cat((unlbl_indices, unlbl_idxs))
        indices = matrix[:, 0]

        new_lbl_idx = np.int_(torch.cat((torch.tensor(self.lbl_idx, device=self.device), indices)).tolist())
        new_unlbl_idx = np.int_(unlbl_indices.tolist())
        self.labeled_loader, self.unlabeled_loader, self.val_loader, self.test_loader, _, _, new_val_idx = \
            get_dataloaders_with_index(path='../data', batch_size=self.batch_size, num_labeled=self.num_labeled,
                                       which_dataset='cifar10', lbl_idxs=new_lbl_idx, unlbl_idxs=new_unlbl_idx,
                                       valid_idxs=self.val_idx, validation=self.validation_set)

        assert np.allclose(self.val_idx, new_val_idx), 'error'
        assert (len(self.labeled_loader.sampler) + len(self.unlabeled_loader.sampler) == 50000), 'error'

        # Replace the real labels with the pseudo labels
        for i in range(matrix.shape[0]):
            index = int(matrix[i, 0].item())
            assert int(matrix[i, 3]) == self.labeled_loader.dataset.targets[index], \
                'true label stored in the matrix must match the dataset target'
            pseudo_label = int(matrix[i, 2].item())
            self.labeled_loader.dataset.targets[index] = pseudo_label

        correct = torch.sum(matrix[:, 4]).item()
        pseudo_acc = correct / matrix.shape[0] * 100 if matrix.shape[0] > 0 else 0
        print('Generated labels: %d\t Correct: %d\t Accuracy: %.2f' % (matrix.shape[0], correct, pseudo_acc))
        print('Training with Labeled / Unlabeled / Validation samples\t %d %d %d' % (len(new_lbl_idx),
                                                                                     len(new_unlbl_idx),
                                                                                     len(self.val_idx)))
        return matrix
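
    # Design note: generate_pseudo_set admits every guess above one global
    # confidence threshold tau, so easy classes can dominate the pseudo-labeled
    # set; generate_pseudo_set_balanced instead keeps the 10% most confident
    # guesses within each class (torch.quantile per class), which keeps the
    # new labeled set class-balanced by construction.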

    def print_threshold_comparison(self, matrix):
        m2 = matrix[matrix[:, 1] >= 0.9, :]
        for tau in [0.9, 0.95, 0.97, 0.99, 0.999]:
            pseudo_labels = m2[m2[:, 1] >= tau, :]
            total = pseudo_labels.shape[0]
            correct = torch.sum(pseudo_labels[:, 4]).item()
            print('Confidence threshold %.3f\t Generated / Correct / Precision\t %d\t%d\t%.2f '
                  % (tau, total, correct, correct / (total + np.finfo(float).eps) * 100))
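
Both generators above consume a prediction matrix whose column layout they index implicitly: [sample index, max softmax confidence, pseudo label, true label, correct flag]. A minimal sketch of how such a matrix could be assembled follows; the helper name and the assumption that the unlabeled loader also yields sample indices are not from the original code:

def build_prediction_matrix(model, loader, device):
    # Hypothetical helper: one row per sample, columns
    # [index, confidence, pseudo label, true label, correct flag].
    model.eval()
    rows = []
    with torch.no_grad():
        for inputs, targets, indices in loader:  # assumes the loader yields indices
            probs = torch.softmax(model(inputs.to(device)), dim=1)
            conf, pred = probs.max(dim=1)
            targets = targets.to(device)
            rows.append(torch.stack([indices.to(device).float(), conf, pred.float(),
                                     targets.float(), (pred == targets).float()], dim=1))
    return torch.cat(rows, dim=0)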
Example No. 22
def test():
    # Reconstructed evaluation loop (the original snippet starts mid-function);
    # assumes the module-level model and test_loader from the __main__ block below.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # replaces the deprecated Variable(images, volatile=True)
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the %d test images: %.2f %%' %
          (total, 100.0 * correct / total))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch FeedForward Example')
    parser.add_argument('--epochs',
                        type=int,
                        default=3,
                        help='number of epochs to train')
    parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        help='batch size')
    args = parser.parse_args()

    train_loader, test_loader = data_loader()
    model = WideResNet(28, 10)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    train(args.epochs)
    test()
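
train() is invoked above but its definition is not part of the snippet; a minimal sketch, assuming it uses the module-level model, optimizer, criterion and train_loader created in the __main__ block:

def train(epochs):
    # Sketch only: one pass over train_loader per epoch using the globals above.
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print('Epoch %d/%d\t loss: %.4f' % (epoch + 1, epochs, running_loss / len(train_loader)))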
Example No. 23
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--stage', default='train', type=str)
    parser.add_argument('--dataset', default='imagenet', type=str)
    parser.add_argument('--lr', default=0.0012, type=float)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--gpus', default='0,1,2,3', type=str)
    parser.add_argument('--weight_decay', default=1e-5, type=float)
    parser.add_argument('--max_epoch', default=30, type=int)
    parser.add_argument('--lr_decay_steps', default='15,20,25', type=str)
    parser.add_argument('--exp', default='', type=str)
    parser.add_argument('--list', default='', type=str)
    parser.add_argument('--resume_path', default='', type=str)
    parser.add_argument('--pretrain_path', default='', type=str)
    parser.add_argument('--n_workers', default=32, type=int)

    parser.add_argument('--network', default='resnet50', type=str)

    global args
    args = parser.parse_args()

    os.makedirs(args.exp, exist_ok=True)
    for subdir in ('runs', 'models', 'logs'):
        os.makedirs(os.path.join(args.exp, subdir), exist_ok=True)

    # logger initialize
    logger = getLogger(args.exp)

    device_ids = [int(x) for x in args.gpus.split(',')]
    device = torch.device('cuda:0')

    if args.dataset.startswith('cifar'):
        train_loader, val_loader = cifar.get_semi_dataloader(args)
    else:
        train_loader, val_loader = imagenet.get_semi_dataloader(args)

    # create model
    if args.network == 'alexnet':
        network = AlexNet(128)
    elif args.network == 'alexnet_cifar':
        network = AlexNet_cifar(128)
    elif args.network == 'resnet18_cifar':
        network = ResNet18_cifar()
    elif args.network == 'resnet50_cifar':
        network = ResNet50_cifar()
    elif args.network == 'wide_resnet28':
        network = WideResNet(28, 10 if args.dataset == 'cifar10' else 100, 2)
    elif args.network == 'resnet18':
        network = resnet18()
    elif args.network == 'resnet50':
        network = resnet50()
    network = nn.DataParallel(network, device_ids=device_ids)
    network.to(device)

    # 2048 features / 1000 classes match resnet50 on ImageNet; adjust for the
    # smaller backbones and CIFAR variants above.
    classifier = nn.Linear(2048, 1000).to(device)
    # create optimizer

    parameters = network.parameters()
    optimizer = torch.optim.SGD(
        parameters,
        lr=args.lr,
        momentum=0.9,
        weight_decay=args.weight_decay,
    )

    cls_optimizer = torch.optim.SGD(
        classifier.parameters(),
        lr=args.lr * 50,
        momentum=0.9,
        weight_decay=args.weight_decay,
    )

    cudnn.benchmark = True

    # create summary writer
    global writer
    writer = SummaryWriter(comment='SemiSupervised',
                           logdir=os.path.join(args.exp, 'runs'))

    # create criterion
    criterion = nn.CrossEntropyLoss()

    logging.info(beautify(args))
    start_epoch = 0
    if args.pretrain_path not in ('', 'none'):
        logging.info('loading pretrained file from {}'.format(
            args.pretrain_path))
        checkpoint = torch.load(args.pretrain_path)
        state_dict = checkpoint['state_dict']
        valid_state_dict = {
            k: v
            for k, v in state_dict.items()
            if k in network.state_dict() and 'fc.' not in k
        }
        for k, v in network.state_dict().items():
            if k not in valid_state_dict:
                logging.info('{}: Random Init'.format(k))
                valid_state_dict[k] = v
        # logging.info(valid_state_dict.keys())
        network.load_state_dict(valid_state_dict)
    else:
        logging.info('Training SemiSupervised Learning From Scratch')

    logging.info('start training')
    best_acc = 0.0
    try:
        for i_epoch in range(start_epoch, args.max_epoch):
            train(i_epoch, network, classifier, criterion, optimizer,
                  cls_optimizer, train_loader, device)

            checkpoint = {
                'epoch': i_epoch + 1,
                'state_dict': network.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(checkpoint,
                       os.path.join(args.exp, 'models', 'checkpoint.pth'))
            adjust_learning_rate(args.lr_decay_steps, optimizer, i_epoch)
            if i_epoch % 2 == 0:
                acc1, acc5 = validate(i_epoch, network, classifier, val_loader,
                                      device)
                if acc1 >= best_acc:
                    best_acc = acc1
                    torch.save(checkpoint,
                               os.path.join(args.exp, 'models', 'best.pth'))
                writer.add_scalar('acc1', acc1, i_epoch + 1)
                writer.add_scalar('acc5', acc5, i_epoch + 1)

            if i_epoch in [30, 60, 120, 160, 200]:
                torch.save(
                    checkpoint,
                    os.path.join(args.exp, 'models',
                                 '{}.pth'.format(i_epoch + 1)))

            # acc1/acc5 hold the most recent validation results, which are
            # refreshed only every second epoch above.
            logging.info(
                colorful('[Epoch: {}] val acc: {:.4f}/{:.4f}'.format(
                    i_epoch, acc1, acc5)))
            logging.info(
                colorful('[Epoch: {}] best acc: {:.4f}'.format(
                    i_epoch, best_acc)))

            with torch.no_grad():
                for name, param in network.named_parameters():
                    if 'bn' not in name:
                        writer.add_histogram(name, param, i_epoch)

            # cluster
    except KeyboardInterrupt:
        logging.info('KeyboardInterrupt at {} Epochs'.format(i_epoch))
        exit()
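
adjust_learning_rate() is called in the training loop above but not defined in this snippet; a minimal sketch, assuming a step decay by a factor of 10 at each epoch listed in --lr_decay_steps and access to the global args:

def adjust_learning_rate(lr_decay_steps, optimizer, epoch):
    # Hypothetical reconstruction: scale the base learning rate by 0.1 for
    # every decay step the current epoch has already passed.
    steps = [int(s) for s in lr_decay_steps.split(',')]
    decay = 0.1 ** sum(epoch >= s for s in steps)
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.lr * decay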
Example No. 24
def load(config):
    if config.model == 'wideresnet':
        model = WideResNet(num_classes=config.dataset_classes)
        ema_model = WideResNet(num_classes=config.dataset_classes)
    else:
        model = CNN13(num_classes=config.dataset_classes)
        ema_model = CNN13(num_classes=config.dataset_classes)

    if config.semi_supervised == 'mix_match':
        semi_supervised = MixMatch(config)
        semi_supervised_loss = mix_match_loss
    elif config.semi_supervised == 'pseudo_label':
        semi_supervised = PseudoLabel(config)
        semi_supervised_loss = pseudo_label_loss
    else:
        raise ValueError('unknown semi_supervised method: {}'.format(config.semi_supervised))

    model.to(config.device)
    ema_model.to(config.device)

    torch.backends.cudnn.benchmark = True

    optimizer = Adam(model.parameters(), lr=config.learning_rate)
    ema_optimizer = WeightEMA(model, ema_model, alpha=config.ema_decay)

    if config.resume:
        checkpoint = torch.load(config.checkpoint_path, map_location=config.device)

        model.load_state_dict(checkpoint['model_state'])
        ema_model.load_state_dict(checkpoint['ema_model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])

        # optimizer state should be moved to corresponding device
        for optimizer_state in optimizer.state.values():
            for k, v in optimizer_state.items():
                if isinstance(v, torch.Tensor):
                    optimizer_state[k] = v.to(config.device)

    return model, ema_model, optimizer, ema_optimizer, semi_supervised, semi_supervised_loss
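
A usage sketch for load(); the attribute names mirror the config fields the function reads, while the concrete values are placeholder assumptions:

from types import SimpleNamespace

config = SimpleNamespace(
    model='wideresnet',
    dataset_classes=10,
    semi_supervised='mix_match',
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    learning_rate=0.002,
    ema_decay=0.999,
    resume=False,
    checkpoint_path='',
)
model, ema_model, optimizer, ema_optimizer, ssl_method, ssl_loss = load(config)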
Example No. 25
        '--path_to_save',
        type=str,
        default='./results/',
        help='directory to save the results, must already exist')

    hps = parser.parse_args()
    hps.n_labels = 43 if hps.dataset == 'gts' else 10
    hps.las = hps.las in ['True', 'true', '1']
    if hps.eps == -1:
        hps.eps = 0.3 if hps.dataset == 'mnist' else 0.0314
    assert hps.p == 'linf', 'Lp-norm not supported'

    ### load models and datasets; the pretrained checkpoints are available at https://github.com/yaodongyu/TRADES
    if hps.dataset == 'cifar10':
        from models.wideresnet import WideResNet
        device = torch.device("cuda")
        model = WideResNet().to(device)
        model.load_state_dict(torch.load('./checkpoints/model_cifar_wrn.pt'))
        model.eval()

        X_data = np.load('./data_attack/cifar10_X.npy')
        Y_data = np.load('./data_attack/cifar10_Y.npy')
        X_data = np.transpose(X_data, (0, 3, 1, 2))

    elif hps.dataset == 'mnist':
        from models.small_cnn import SmallCNN
        device = torch.device("cuda")
        model = SmallCNN().to(device)
        model.load_state_dict(
            torch.load('./checkpoints/model_mnist_smallcnn.pt'))
        model.eval()
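
The snippet is truncated here; as a usage sketch, the loaded model can be sanity-checked for clean accuracy before running an attack. The batch size is arbitrary, and it assumes X_data / Y_data have been loaded for the chosen dataset as in the cifar10 branch (TRADES checkpoints expect inputs scaled to [0, 1]):

batch = 256  # arbitrary evaluation batch size
correct = 0
with torch.no_grad():
    for i in range(0, len(X_data), batch):
        x = torch.from_numpy(X_data[i:i + batch]).float().to(device)
        y = torch.from_numpy(Y_data[i:i + batch]).long().to(device)
        correct += (model(x).argmax(dim=1) == y).sum().item()
print('Clean accuracy: %.2f %%' % (100.0 * correct / len(X_data)))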