Beispiel #1
0
def load_baseline_model(args):
    """
    Build the baseline classifier together with its data loaders.

    :param args: namespace-like object. Attributes read here:
        ``dataset`` ('cifar10', 'cifar100' or 'mnist'), ``model``
        ('resnet18' or 'wideresnet'), ``batch_size``,
        ``data_augmentation`` (CIFAR datasets only) and
        ``load_baseline_checkpoint`` (path of a checkpoint to restore, or a
        falsy value to skip restoring). For 'mnist' the attributes
        ``datasize``, ``valsize`` and ``testsize`` are overwritten on
        ``args`` (see the note in that branch).
    :return: tuple ``(model, train_loader, val_loader, test_loader,
        checkpoint)``. ``model`` has been moved to the GPU and put in
        training mode; ``checkpoint`` is the object returned by
        ``torch.load`` or ``None`` when no checkpoint was requested.
    :raises ValueError: if ``args.dataset`` or ``args.model`` is not one of
        the supported names.
    """
    if args.dataset == 'cifar10':
        num_classes = 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(
            args.batch_size, val_split=True,
            augmentation=args.data_augmentation)
    elif args.dataset == 'cifar100':
        num_classes = 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(
            args.batch_size, val_split=True,
            augmentation=args.data_augmentation)
    elif args.dataset == 'mnist':
        # MNIST has ten digit classes. The original code never set this, so
        # building the model below failed on an unbound local.
        num_classes = 10

        # NOTE(review): the subset sizes are force-overwritten on the caller's
        # args object. This looks like a debugging leftover; it is kept
        # because callers may rely on it -- confirm before removing.
        args.datasize, args.valsize, args.testsize = 100, 100, 100
        num_train = args.datasize
        # NOTE(review): unreachable while the override above is in place.
        if args.datasize == -1:
            num_train = 50000

        from data_loaders import load_mnist
        train_loader, val_loader, test_loader = load_mnist(
            args.batch_size,
            subset=[args.datasize, args.valsize, args.testsize],
            num_train=num_train)
    else:
        raise ValueError(
            "Unsupported dataset {!r}; expected 'cifar10', 'cifar100' or "
            "'mnist'".format(args.dataset))

    if args.model == 'resnet18':
        cnn = ResNet18(num_classes=num_classes)
    elif args.model == 'wideresnet':
        cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10,
                         dropRate=0.3)
    else:
        raise ValueError(
            "Unsupported model {!r}; expected 'resnet18' or "
            "'wideresnet'".format(args.model))

    checkpoint = None
    if args.load_baseline_checkpoint:
        checkpoint = torch.load(args.load_baseline_checkpoint)
        cnn.load_state_dict(checkpoint['model_state_dict'])

    model = cnn.cuda()
    model.train()
    return model, train_loader, val_loader, test_loader, checkpoint
    train_loader, val_loader, test_loader = data_loaders.load_cifar10(
        args.batch_size, val_split=True, augmentation=args.data_augmentation)
elif args.dataset == 'cifar100':
    num_classes = 100
    train_loader, val_loader, test_loader = data_loaders.load_cifar100(
        args.batch_size, val_split=True, augmentation=args.data_augmentation)

if args.model == 'resnet18':
    cnn = ResNet18(num_classes=num_classes)
elif args.model == 'wideresnet':
    cnn = WideResNet(depth=28,
                     num_classes=num_classes,
                     widen_factor=10,
                     dropRate=0.3)

cnn = cnn.cuda()
criterion = nn.CrossEntropyLoss().cuda()

if args.optimizer == 'sgdm':
    cnn_optimizer = torch.optim.SGD(cnn.parameters(),
                                    lr=args.lr,
                                    momentum=0.9,
                                    nesterov=True,
                                    weight_decay=args.wdecay)
elif args.optimizer == 'sgd':
    cnn_optimizer = torch.optim.SGD(cnn.parameters(),
                                    lr=args.lr,
                                    weight_decay=args.wdecay)
elif args.optimizer == 'adam':
    cnn_optimizer = torch.optim.Adam(
        cnn.parameters(), lr=args.lr, weight_decay=args.wdecay
def main(args):
    """
    Fine-tune a WideResNet (optionally with attention) and log each epoch.

    Builds train/validation loaders from image list files, creates the
    model and an SGD optimizer with a reduced learning rate for the base
    parameters, then alternates one training pass and one validation pass
    per epoch. The weights with the best validation accuracy seen so far
    are saved when ``args.save`` is truthy, and the learning rate of every
    parameter group is multiplied by 0.1 at the epochs in ``args.schedule``.

    :param args: namespace-like object. Attributes read here: ``log_path``,
        ``root_folder``, ``train_listfile``, ``val_listfile``,
        ``batch_size``, ``num_workers``, ``attention_depth``,
        ``attention_width``, ``has_gates``, ``reg_weight``, ``nlabels``,
        ``lr``, ``ngpu``, ``epochs``, ``save`` and ``schedule``.
        ``args.__dict__`` is used directly as the logged state, so ``lr``
        and the extra state keys are written back onto ``args``.
    :return: None
    """
    harakiri = Harakiri()
    harakiri.set_max_plateau(20)
    train_loss_meter = Meter()
    val_loss_meter = Meter()
    val_accuracy_meter = Meter()
    log = JsonLogger(args.log_path, rand_folder=True)
    log.update(args.__dict__)
    # `state` aliases args.__dict__: every key written below is also visible
    # as an attribute of `args`.
    state = args.__dict__
    state['exp_dir'] = os.path.dirname(log.path)
    state['start_lr'] = state['lr']
    print(state)

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    train_dataset = ImageList(args.root_folder,
                              args.train_listfile,
                              transform=transforms.Compose([
                                  transforms.Resize(256),
                                  transforms.RandomCrop(224),
                                  transforms.RandomHorizontalFlip(),
                                  transforms.ToTensor(),
                                  transforms.Normalize(imagenet_mean,
                                                       imagenet_std)
                              ]))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=False,
                                               num_workers=args.num_workers)
    val_dataset = ImageList(args.root_folder,
                            args.val_listfile,
                            transform=transforms.Compose([
                                transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize(imagenet_mean,
                                                     imagenet_std)
                            ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=False,
                                             num_workers=args.num_workers)

    if args.attention_depth == 0:
        from models.wide_resnet import WideResNet
        model = WideResNet().finetune(args.nlabels).cuda()
    else:
        from models.wide_resnet_attention import WideResNetAttention
        model = WideResNetAttention(args.nlabels, args.attention_depth,
                                    args.attention_width, args.has_gates,
                                    args.reg_weight).finetune(args.nlabels)

    # if args.load != "":
    #     net.load_state_dict(torch.load(args.load), strict=False)
    #     net = net.cuda()

    # Base parameters train at a tenth of the classifier learning rate.
    optimizer = optim.SGD(
        [{
            'params': model.get_base_params(),
            'lr': args.lr * 0.1
        }, {
            'params': model.get_classifier_params()
        }],
        lr=args.lr,
        weight_decay=1e-4,
        momentum=0.9,
        nesterov=True)

    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, range(args.ngpu)).cuda()
    else:
        model = model.cuda()

    def train():
        """
        Run one optimisation pass over ``train_loader``.

        Stores the running training loss in ``state['train_loss']``.
        """
        model.train()
        for data, label in train_loader:
            # `async` is a reserved word since Python 3.7; `non_blocking` is
            # the replacement keyword for the same option.
            data = data.cuda(non_blocking=True)
            label = label.cuda()
            optimizer.zero_grad()
            if args.attention_depth > 0:
                output, loss = model(data)
                # The model's auxiliary loss only contributes when a
                # regularisation weight is configured.
                if args.reg_weight > 0:
                    loss = loss.mean()
                else:
                    loss = 0
            else:
                loss = 0
                output = model(data)
            loss += F.nll_loss(output, label)
            loss.backward()
            optimizer.step()
            train_loss_meter.update(loss.item(), data.size(0))
        # NOTE(review): the meters are created once and never reset here, so
        # unless Meter.mean() resets internally the reported values are
        # running averages over all epochs -- confirm against Meter.
        state['train_loss'] = train_loss_meter.mean()

    def val():
        """
        Evaluate the model on ``val_loader`` without tracking gradients.

        Stores the results in ``state['val_loss']`` and
        ``state['val_accuracy']``.
        """
        model.eval()
        # torch.no_grad() replaces the removed `volatile=True` Variable flag.
        with torch.no_grad():
            for data, label in val_loader:
                data = data.cuda(non_blocking=True)
                label = label.cuda()
                if args.attention_depth > 0:
                    # The auxiliary loss is not used for validation.
                    output, _ = model(data)
                else:
                    output = model(data)
                loss = F.nll_loss(output, label)
                val_loss_meter.update(loss.item(), data.size(0))
                preds = output.max(1)[1]
                # NOTE(review): passes the number of correct predictions and
                # the batch size; presumably Meter divides the accumulated
                # value by the accumulated count -- confirm against Meter.
                val_accuracy_meter.update(
                    (preds == label).float().sum().item(), data.size(0))
        state['val_loss'] = val_loss_meter.mean()
        state['val_accuracy'] = val_accuracy_meter.mean()

    best_accuracy = 0
    for epoch in range(args.epochs):
        train()
        val()
        harakiri.update(epoch, state['val_accuracy'])
        if state['val_accuracy'] > best_accuracy:
            best_accuracy = state['val_accuracy']
            if args.save:
                torch.save(model.state_dict(),
                           os.path.join(state["exp_dir"], "model.pytorch"))
        state['epoch'] = epoch + 1
        log.update(state)
        print(state)
        if (epoch + 1) in args.schedule:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
            state['lr'] *= 0.1
Beispiel #4
0
class NeuralNet:
    """
    Wrapper that trains, tests and scores a WideResNet classifier whose
    targets are transformation classes (``cfg.NUM_TRANS`` outputs).

    The model and the loss are placed on the GPU when one is available and
    on the CPU otherwise.
    """

    def __init__(self):
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
            print('WARNING: Found no valid GPU device - Running on CPU')

        self.model = WideResNet(DEPTH, cfg.NUM_TRANS)
        # .to() works for both device types; .cuda(device) raises when the
        # CPU fallback above was selected.
        self.model.to(self.device)
        self.criterion = torch.nn.CrossEntropyLoss().to(self.device)
        self.optimizer = None

    def test(self, test_gen):
        """
        Evaluate the model on ``test_gen`` and print progress every 100
        batches.

        :param test_gen: sized iterable yielding ``(input, target, _)``
            batches.
        :return: average top-1 accuracy (percent) over all batches.
        """
        batch_time = self.AverageMeter('Time', ':6.3f')
        losses = self.AverageMeter('Loss', ':.4e')
        top1 = self.AverageMeter('Acc@1', ':6.2f')
        progress = self.ProgressMeter(len(test_gen),
                                      batch_time,
                                      losses,
                                      top1,
                                      prefix='Test: ')

        # switch to evaluate mode
        self.model.eval()

        with torch.no_grad():
            end = time.time()
            for i, (inputs, target, _) in enumerate(test_gen):
                inputs = inputs.to(self.device, non_blocking=True)
                target = target.to(self.device, non_blocking=True)

                # Compute output
                output, _ = self.model([inputs, target])
                loss = self.criterion(output, target)

                # Measure accuracy and record loss
                acc1 = self._accuracy(output, target, topk=(1,))
                losses.update(loss.item(), inputs.size(0))
                top1.update(acc1[0].cpu().detach().item(), inputs.size(0))

                # Measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # Print to screen
                if i % 100 == 0:
                    progress.print(i)

            # TODO: this should also be done with the ProgressMeter
            print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))

        return top1.avg

    def train(self,
              train_gen,
              test_gen,
              epochs,
              lr=0.0001,
              lr_plan=None,
              momentum=0.9,
              wd=5e-4):
        """
        Train with SGD for ``epochs`` epochs, validating after each one.

        :param train_gen: sized iterable of ``(input, target, _)`` training
            batches.
        :param test_gen: sized iterable of ``(input, target, _)`` batches
            used for the per-epoch validation.
        :param epochs: number of epochs to run.
        :param lr: initial learning rate.
        :param lr_plan: optional dict mapping an epoch index to the learning
            rate to switch to at the start of that epoch.
        :param momentum: SGD momentum.
        :param wd: weight decay.
        """
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=lr,
                                         momentum=momentum,
                                         weight_decay=wd)
        #self.optimizer = torch.optim.Adam(self.model.parameters())

        for epoch in range(epochs):
            self._adjust_lr_rate(self.optimizer, epoch, lr_plan)
            print("=> Training (specific label)")
            self._train_step(train_gen, epoch, self.optimizer)
            print("=> Validation (entire dataset)")
            self.test(test_gen)

    def _train_step(self, train_gen, epoch, optimizer):
        """
        Run one epoch of training over ``train_gen``.

        :param train_gen: sized iterable of ``(input, target, _)`` batches.
        :param epoch: epoch index, used only for the progress prefix.
        :param optimizer: optimizer stepping the model parameters.
        """
        self.model.train()

        batch_time = self.AverageMeter('Time', ':6.3f')
        data_time = self.AverageMeter('Data', ':6.3f')
        losses = self.AverageMeter('Loss', ':.4e')
        top1 = self.AverageMeter('Acc@1', ':6.2f')
        progress = self.ProgressMeter(len(train_gen),
                                      batch_time,
                                      data_time,
                                      losses,
                                      top1,
                                      prefix="Epoch: [{}]".format(epoch))

        end = time.time()
        for i, (inputs, target, _) in enumerate(train_gen):
            # measure data loading time
            data_time.update(time.time() - end)

            inputs = inputs.to(self.device, non_blocking=True)
            target = target.to(self.device, non_blocking=True)

            # Compute output (the second model output is unused here)
            output, _ = self.model([inputs, target])
            loss = self.criterion(output, target)

            # measure accuracy and record loss
            acc1 = self._accuracy(output, target, topk=(1,))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1[0].cpu().detach().item(), inputs.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 100 == 0:
                progress.print(i)

    def evaluate(self, eval_gen):
        """
        Score every sample by the mean model output at its target classes.

        NOTE(review): presumably each batch holds ``cfg.NUM_TRANS``
        consecutive transformed copies of every original sample; the
        ``view(-1, cfg.NUM_TRANS)`` reduction and the label sub-sampling
        below both rely on that layout -- confirm against the data
        generator.

        :param eval_gen: iterable yielding ``(input, target, labels)``
            batches, where ``target`` is the transformation class and
            ``labels`` the true label.
        :return: tuple ``(scores, labels)`` of concatenated tensors, one
            entry per group of ``cfg.NUM_TRANS`` rows.
        """
        # switch to evaluate mode
        self.model.eval()

        score_func_list = []
        labels_list = []
        with torch.no_grad():

            for inputs, target, labels in eval_gen:
                inputs = inputs.to(self.device, non_blocking=True)
                # Target - the transformation class
                target = target.to(self.device, non_blocking=True)
                # The true label: keep the first row of every group of
                # NUM_TRANS rows.
                labels = labels[[
                    cfg.NUM_TRANS * x
                    for x in range(len(labels) // cfg.NUM_TRANS)
                ]].to(self.device, non_blocking=True)

                # Compute output
                # #TODO: Rewrite this code section, can be more efficient
                output_SM = self.model([inputs, target])

                # One-hot mask that selects, per row, the output at the
                # target class.
                target_mat = torch.zeros_like(output_SM[0])
                target_mat[range(output_SM[0].shape[0]), target] = 1
                target_SM = (target_mat * output_SM[0]).sum(dim=1).view(
                    -1, cfg.NUM_TRANS).sum(dim=1)

                score_func_list.append(1 / cfg.NUM_TRANS * target_SM)
                labels_list.append(labels)

        return torch.cat(score_func_list), torch.cat(labels_list)

    def _adjust_lr_rate(self, optimizer, epoch, lr_dict):
        """
        Set the learning rate scheduled for ``epoch``, if there is one.

        :param optimizer: optimizer whose parameter groups are updated.
        :param epoch: current epoch index.
        :param lr_dict: dict mapping epoch index to learning rate, or
            ``None`` to leave the optimizer untouched.
        """
        if lr_dict is None:
            return

        for key, value in lr_dict.items():
            if epoch == key:
                print("=> New learning rate set of {}".format(value))
                for param_group in optimizer.param_groups:
                    param_group['lr'] = value

    def summary(self, x_size, print_it=True):
        """Delegate to the wrapped model's ``summary`` and return its result."""
        return self.model.summary(x_size, print_it=print_it)

    def print_weights(self):
        """Delegate to the wrapped model's ``print_weights``."""
        self.model.print_weights()

    @staticmethod
    def _accuracy(output, target, topk=(1,)):
        """
        Computes the accuracy over the k top predictions for the specified
        values of k.

        :param output: score tensor with one row per sample and one column
            per class.
        :param target: tensor holding the true class index of each sample.
        :param topk: a single ``k`` or an iterable of ``k`` values. A bare
            int is accepted because existing callers pass ``(1)``, which is
            the int 1 rather than a tuple.
        :return: list with one single-element tensor per requested ``k``,
            holding the accuracy in percent.
        """
        ks = (topk,) if isinstance(topk, int) else tuple(topk)
        with torch.no_grad():
            maxk = max(ks)
            batch_size = target.size(0)

            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = []
            for k in ks:
                # A sample counts as correct when its target appears among
                # its k highest-scoring classes.
                correct_k = correct[:k].reshape(-1).float().sum(
                    0, keepdim=True)
                res.append(correct_k.mul_(100.0 / batch_size))

            return res

    class AverageMeter(object):
        """Computes and stores the average and current value"""
        def __init__(self, name, fmt=':f'):
            self.name = name
            self.fmt = fmt
            self.reset()

        def reset(self):
            """Clear the current value, sum, count and average."""
            self.val = 0
            self.avg = 0
            self.sum = 0
            self.count = 0

        def update(self, val, n=1):
            """Record ``val`` as the mean of ``n`` new observations."""
            self.val = val
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count

        def __str__(self):
            fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
            return fmtstr.format(**self.__dict__)

    class ProgressMeter(object):
        """Prints a ``[batch/total]`` counter followed by a set of meters."""
        def __init__(self, num_batches, *meters, prefix=""):
            self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
            self.meters = meters
            self.prefix = prefix

        def print(self, batch):
            """Print the prefix, the batch counter and every meter."""
            entries = [self.prefix + self.batch_fmtstr.format(batch)]
            entries += [str(meter) for meter in self.meters]
            print('\t'.join(entries))

        def _get_batch_fmtstr(self, num_batches):
            """Build a counter format padded to the width of ``num_batches``."""
            num_digits = len(str(num_batches))
            fmt = '{:' + str(num_digits) + 'd}'
            return '[' + fmt + '/' + fmt.format(num_batches) + ']'