Example 1
import torch
# assuming torchmeta's DataParallel, which forwards the `params` keyword
from torchmeta.modules import DataParallel


def test_dataparallel_params(model, params):
    device = torch.device('cuda:0')
    model = DataParallel(model)
    model.to(device=device)

    inputs = torch.rand(5, 2).to(device=device)
    outputs = model(inputs, params=params)

    assert outputs.shape == (5, 1)
    assert outputs.device == device
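
The test above assumes a model whose forward pass accepts a `params` keyword, i.e. a torchmeta MetaModule. A minimal sketch of such a model, assuming torchmeta's MetaModule and MetaLinear (the sizes match the (5, 2) -> (5, 1) shapes asserted above):

from torchmeta.modules import MetaModule, MetaLinear


class ToyMetaModel(MetaModule):
    def __init__(self):
        super().__init__()
        # 2 input features -> 1 output, matching the shape assertions above
        self.fc = MetaLinear(2, 1)

    def forward(self, inputs, params=None):
        # a MetaModule lets the caller substitute parameters at call time;
        # get_subdict extracts the entries belonging to the 'fc' submodule
        return self.fc(inputs, params=self.get_subdict(params, 'fc'))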
Example 2
import torch
# assuming torchmeta's DataParallel and gradient-based update utility
from torchmeta.modules import DataParallel
from torchmeta.utils.gradient_based import gradient_update_parameters


def test_dataparallel_params_maml(model):
    device = torch.device('cuda:0')
    model = DataParallel(model)
    model.to(device=device)

    train_inputs = torch.rand(5, 2).to(device=device)
    train_outputs = model(train_inputs)

    inner_loss = train_outputs.sum()  # Dummy loss
    params = gradient_update_parameters(model, inner_loss)

    test_inputs = torch.rand(5, 2).to(device=device)
    test_outputs = model(test_inputs, params=params)

    assert test_outputs.shape == (5, 1)
    assert test_outputs.device == device

    outer_loss = test_outputs.sum()  # Dummy loss
    outer_loss.backward()
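
A minimal sketch of driving both tests by hand, assuming the hypothetical ToyMetaModel from above and a machine with at least one CUDA device:

from collections import OrderedDict

model = ToyMetaModel()
params = OrderedDict(model.meta_named_parameters())  # explicit parameter dict
test_dataparallel_params(model, params)              # Example 1
test_dataparallel_params_maml(ToyMetaModel())        # Example 2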
Example 3
        torch.backends.cudnn.benchmark = True
    elif args.use_cuda:
        raise RuntimeError('You are using GPU mode, but GPUs are not available!')
    
    # construct model and optimizer
    args.image_len = 28 if args.train_data == 'omniglot' else 84
    args.out_channels, _ = get_outputs_c_h(args.backbone, args.image_len)
    
    model = ConvolutionalNeuralNetwork(args.backbone, args.out_channels,
                                       args.num_ways)
    teacher_model = ConvolutionalNeuralNetwork(args.teacher_backbone,
                                               args.out_channels,
                                               args.num_ways)

    if args.use_cuda:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        num_gpus = torch.cuda.device_count()
        if args.multi_gpu:
            model = DataParallel(model)
            teacher_model = DataParallel(teacher_model)

        model = model.cuda()
        teacher_model = teacher_model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0005)

    # load the trained teacher model
    if args.teacher_resume and args.teacher_resume_folder is not None:
        teacher_model_path = os.path.join(
            args.teacher_resume_folder,
            '_'.join(['teacher', args.train_data, args.test_data,
                      args.teacher_backbone, 'max_acc']) + '.pt')
        teacher_state = torch.load(teacher_model_path)
        if args.multi_gpu:
            teacher_model.module.load_state_dict(teacher_state)
        else:
            teacher_model.load_state_dict(teacher_state)
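
The resume branch above expects the teacher weights to have been saved from the underlying module, so that the keys carry no 'module.' prefix when multi_gpu is set. A minimal sketch of the matching save side (save_teacher is a hypothetical helper, not part of the source):

def save_teacher(teacher_model, args, folder):
    # unwrap DataParallel before saving so keys have no 'module.' prefix
    state = (teacher_model.module.state_dict()
             if args.multi_gpu else teacher_model.state_dict())
    name = '_'.join(['teacher', args.train_data, args.test_data,
                     args.teacher_backbone, 'max_acc']) + '.pt'
    torch.save(state, os.path.join(folder, name))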
Example 4
class Classifier(nn.Module):
    def __init__(self, model_dict, dataset, device):
        super().__init__()

        self.dataset = dataset

        self.model_dict = model_dict
        if self.model_dict['name'] == 'resnet18':
            if self.model_dict['pretrained']:
                self.net = models.resnet18(pretrained=True)
                self.net.fc = nn.Linear(512, self.dataset.n_classes)
            else:
                self.net = models.resnet18(num_classes=self.dataset.n_classes)

        elif self.model_dict['name'] == 'resnet18_meta':
            if self.model_dict.get('pretrained', True):
                self.net = resnet_meta.resnet18(pretrained=True)
                self.net.fc = MetaLinear(512, self.dataset.n_classes)
            else:
                self.net = resnet_meta.resnet18(
                    num_classes=self.dataset.n_classes)
        elif self.model_dict['name'] == 'resnet18_meta_2':
            self.net = resnet_meta_2.ResNet18(nc=3,
                                              nclasses=self.dataset.n_classes)

        elif self.model_dict['name'] == 'resnet18_meta_old':
            self.net = resnet_meta_old.ResNet18(
                nc=3, nclasses=self.dataset.n_classes)

        else:
            raise ValueError('network %s does not exist' % model_dict['name'])

        if device.type == 'cuda':
            self.net = DataParallel(self.net)
        self.net.to(device)
        # set optimizer
        self.opt_dict = model_dict['opt']
        self.lr_init = self.opt_dict['lr']
        if self.model_dict['opt']['name'] == 'sps':
            n_batches_per_epoch = 120
            self.opt = sps.Sps(self.net.parameters(),
                               n_batches_per_epoch=n_batches_per_epoch,
                               c=0.5,
                               adapt_flag='smooth_iter',
                               eps=0,
                               eta_max=None)
        else:
            self.opt = optim.SGD(self.net.parameters(),
                                 lr=self.opt_dict['lr'],
                                 momentum=self.opt_dict['momentum'],
                                 weight_decay=self.opt_dict['weight_decay'])

        # variables
        self.device = device

    def get_state_dict(self):
        state_dict = {
            'net': self.net.state_dict(),
            'opt': self.opt.state_dict(),
        }

        return state_dict

    def load_state_dict(self, state_dict):
        self.net.load_state_dict(state_dict['net'])
        self.opt.load_state_dict(state_dict['opt'])

    def on_trainloader_start(self, epoch):
        if self.opt_dict['sched']:
            ut.adjust_learning_rate_netC(self.opt, epoch, self.lr_init,
                                         self.model_dict['name'],
                                         self.dataset.name)

    def train_on_batch(self, batch):
        images = batch['images'].to(self.device, non_blocking=True)
        labels = batch['labels'].to(self.device, non_blocking=True)

        logits = self.net(images)
        loss = F.cross_entropy(logits, labels, reduction="mean")

        self.opt.zero_grad()
        loss.backward()

        if self.opt_dict['name'] == 'sps':
            self.opt.step(loss=loss)
        else:
            self.opt.step()

        return loss.item()
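
A minimal sketch of a training loop driving this class, assuming batches are dicts with 'images' and 'labels' keys, as train_on_batch expects (model_dict, dataset, trainloader and num_epochs are assumed to be defined):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clf = Classifier(model_dict, dataset, device)

for epoch in range(num_epochs):
    clf.on_trainloader_start(epoch)  # apply the LR schedule, if enabled
    for batch in trainloader:        # batch = {'images': ..., 'labels': ...}
        loss = clf.train_on_batch(batch)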
Example 5
class Augmenter(nn.Module):
    def __init__(self, model_dict, dataset, device):
        super().__init__()

        if model_dict['name'] == 'stn':
            self.net = stn.STN(isize=dataset.image_size,
                               n_channels=dataset.nc,
                               n_filters=64,
                               nz=100,
                               datasetmean=dataset.mean,
                               datasetstd=dataset.std)

        elif model_dict['name'] == 'small_affine':
            self.net = small_affine.smallAffine(
                nz=6,
                transformation=model_dict['transform'],
                datasetmean=dataset.mean,
                datasetstd=dataset.std)

        elif model_dict['name'] == 'affine_color':
            self.net = affine_color.affineColor(nz=10,
                                                datasetmean=dataset.mean,
                                                datasetstd=dataset.std)

        else:
            raise ValueError('network %s does not exist' % model_dict['name'])

        if device.type == 'cuda':
            self.net = DataParallel(self.net)

        self.net.to(device)

        self.device = device
        self.factor = model_dict['factor']
        self.name = model_dict['name']

        if model_dict['name'] != 'random_augmenter':
            self.opt_dict = model_dict['opt']
            self.lr_init = self.opt_dict['lr']
            self.opt = optim.SGD(self.net.parameters(),
                                 lr=self.opt_dict['lr'],
                                 momentum=self.opt_dict['momentum'],
                                 weight_decay=self.opt_dict['weight_decay'])

    def cycle(self, iterable):
        # loop over the iterable forever, re-creating the iterator on exhaustion
        iterator = iter(iterable)
        while True:
            try:
                yield next(iterator)
            except StopIteration:
                iterator = iter(iterable)

    def get_state_dict(self):
        state_dict = {}
        if hasattr(self, 'opt'):
            state_dict['net'] = self.net.state_dict()
            state_dict['opt'] = self.opt.state_dict()

        return state_dict

    def load_state_dict(self, state_dict):
        if hasattr(self, 'opt'):
            self.net.load_state_dict(state_dict['net'])
            self.opt.load_state_dict(state_dict['opt'])

    def apply_augmentation(self, images, labels):
        # apply augmentation to the given images
        factor = self.factor
        if factor > 1:
            labels = labels.repeat(factor)
            images = images.repeat(factor, 1, 1, 1)

        # anomaly detection helps localize NaN/inf in backward (debug aid; slows training)
        with torch.autograd.set_detect_anomaly(True):
            augimages, transformations = self.net(images)

        return augimages, labels, transformations

    def on_trainloader_start(self, epoch, valloader, netC):
        # Get slope
        if hasattr(self, 'slope_annealing'):
            self.slope = get_slope(self.slope_annealing, epoch)
        # Update optimizer
        if self.opt_dict['sched']:
            ut.adjust_learning_rate_netA(self.opt, epoch, self.lr_init)

        # initialize momentums
        if netC.opt.defaults['momentum']:
            self.moms = OrderedDict()
            for (name, p) in netC.net.named_parameters():
                self.moms[name] = torch.zeros(p.shape).to(self.device)

        self.epoch = epoch
        # Cycle through val_loader
        self.val_gen = self.cycle(valloader)

    def train_on_batch(self, batch, netC):
        self.train()
        images = batch['images'].to(self.device, non_blocking=True)
        labels = batch['labels'].to(self.device, non_blocking=True)
        images, labels, transformations = self.apply_augmentation(
            images, labels)

        # Use classifier
        logits = netC.net(images)
        loss_clf = F.cross_entropy(logits, labels, reduction="mean")

        netC.opt.zero_grad()

        if self.name in ['random_augmenter']:
            # Update the classifier only
            loss_clf.backward()
            netC.opt.step()

            return loss_clf

        elif self.name in ['stn']:
            # Update the style transformer network
            self.opt.zero_grad()
            loss_clf.backward()
            netC.opt.step()
            self.opt.step()

            return loss_clf

        else:
            # Update the augmenter through a validation batch
            # Calculate new weights w^t+1 to calculate the validation loss
            batch_val = next(self.val_gen)
            valimages = batch_val['images'].to(self.device, non_blocking=True)
            vallabels = batch_val['labels'].to(self.device, non_blocking=True)

            # construct graph
            loss_clf.backward(create_graph=True, retain_graph=True)

            # for p in netC.net.parameters():
            #     p.requires_grad = False  # freeze C

            self.w_t_1 = OrderedDict()
            lr = ut.adjust_learning_rate_netC(netC.opt,
                                              self.epoch,
                                              netC.lr_init,
                                              netC.model_dict['name'],
                                              netC.dataset.name,
                                              return_lr=True)  # get step size

            if netC.opt.defaults['momentum']:
                for name, p in netC.net.named_parameters():
                    p.requires_grad = False
                    self.moms[name].detach_()
                    # update the momentum buffer
                    self.moms[name] = (netC.opt.defaults['momentum'] *
                                       self.moms[name] + p.grad)
                    # compute the future weights w^{t+1}
                    self.w_t_1[name] = p - lr * self.moms[name]

            else:
                for name, p in netC.net.named_parameters():
                    p.requires_grad = False
                    # compute the future weights w^{t+1}
                    self.w_t_1[name] = p - lr * p.grad

            # Calculate validation loss
            valoutput = netC.net(valimages, params=self.w_t_1)
            loss_aug = F.cross_entropy(valoutput, vallabels, reduction='mean')
            self.opt.zero_grad()
            loss_aug.backward()
            self.opt.step()
            del self.w_t_1

            # After gradient is computed for A, unfreeze C
            for p in netC.net.parameters():
                p.requires_grad = True
            netC.opt.step()

            del images
            del labels
            gc.collect()

            return float(loss_clf.item()), transformations

    def __call__(self, img):
        img = img.unsqueeze(0)  # add a batch dimension
        img, _ = self.net(img)  # net returns (augmented images, transformations)
        return img.squeeze(0)   # drop the batch dimension
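
Since __call__ adds and removes the batch dimension itself, the Augmenter can double as a per-image transform. A hedged usage sketch (the single (C, H, W) image tensor is an assumption about the data format):

aug = Augmenter(model_dict, dataset, device)
img = batch['images'][0].to(device)  # assumed: one (C, H, W) image tensor
aug_img = aug(img)                   # transformed image, same shape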
Example 6
    elif args.use_cuda:
        raise RuntimeError(
            'You are using GPU mode, but GPUs are not available!')

    # construct model and optimizer
    args.image_len = 28 if args.train_data == 'omniglot' else 84
    args.out_channels, _ = get_outputs_c_h(args.backbone, args.image_len)

    model = ConvolutionalNeuralNetwork(args.backbone, args.out_channels,
                                       args.num_ways)

    if args.use_cuda:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        num_gpus = torch.cuda.device_count()
        if args.multi_gpu:
            model = DataParallel(model)

        model = model.cuda()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=0.0005)

    # training from the checkpoint
    if args.resume and args.resume_folder is not None:
        # load checkpoint
        checkpoint_path = os.path.join(args.resume_folder, ('_'.join([
            args.model_name, args.train_data, args.test_data, args.backbone,
            'max_acc'
        ]) + '_checkpoint.pt.tar'))  # tag='max_acc' can be changed
        state = torch.load(checkpoint_path)
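
The checkpoint is loaded but not yet applied. A minimal sketch of restoring it, assuming (hypothetically) that the checkpoint dict stores 'model' and 'optimizer' entries; with multi_gpu, the weights go to model.module so the key names match:

target = model.module if args.multi_gpu else model
target.load_state_dict(state['model'])         # hypothetical key
optimizer.load_state_dict(state['optimizer'])  # hypothetical key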