Example No. 1
def get_model(device):
    """
    :param device: instance of torch.device
    :return: An instance of torch.nn.Module
    """
    num_classes = 2
    if config["dataset"] == "Cifar100":
        num_classes = 100
    elif config["dataset"] == "Cifar10":
        num_classes = 10

    model = {
        "vgg11": lambda: models.VGG("VGG11", num_classes, batch_norm=False),
        "vgg11_bn": lambda: models.VGG("VGG11", num_classes, batch_norm=True),
        "vgg13": lambda: models.VGG("VGG13", num_classes, batch_norm=False),
        "vgg13_bn": lambda: models.VGG("VGG13", num_classes, batch_norm=True),
        "vgg16": lambda: models.VGG("VGG16", num_classes, batch_norm=False),
        "vgg16_bn": lambda: models.VGG("VGG16", num_classes, batch_norm=True),
        "vgg19": lambda: models.VGG("VGG19", num_classes, batch_norm=False),
        "vgg19_bn": lambda: models.VGG("VGG19", num_classes, batch_norm=True),
        "resnet10": lambda: models.ResNet10(num_classes=num_classes),
        "resnet18": lambda: models.ResNet18(num_classes=num_classes),
        "resnet34": lambda: models.ResNet34(num_classes=num_classes),
        "resnet50": lambda: models.ResNet50(num_classes=num_classes),
        "resnet101": lambda: models.ResNet101(num_classes=num_classes),
        "resnet152": lambda: models.ResNet152(num_classes=num_classes),
        "bert": lambda: models.BertImage(config, num_classes=num_classes),
    }[config["model"]]()

    model.to(device)
    if device.type == "cuda":  # device is a torch.device; compare its .type, not the object
        model = torch.nn.DataParallel(model)
        torch.backends.cudnn.benchmark = True

    return model
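
A minimal usage sketch for the factory above. The module-level config dict and the project-local models package are assumed to exist as in the snippet; the keys and input shape below are illustrative, not from the source.

import torch

config = {"dataset": "Cifar10", "model": "resnet18"}  # hypothetical config values

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_model(device)  # returns a torch.nn.Module already moved to `device`
model.eval()
with torch.no_grad():
    # assumes a CIFAR-style 32x32 input; adjust for other datasets
    logits = model(torch.randn(1, 3, 32, 32, device=device))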
Example No. 2
def create_net(num_classes, dnn='resnet20', **kwargs):
    ext = None
    if dnn in ['resnet20', 'resnet56', 'resnet110']:
        net = models.__dict__[dnn](num_classes=num_classes)
    elif dnn == 'resnet50':
        net = models.__dict__['resnet50'](num_classes=num_classes)
    elif dnn == 'mnistnet':
        net = MnistNet()
    elif dnn == 'mnistflnet':
        net = MnistFLNet()
    elif dnn == 'cifar10flnet':
        net = Cifar10FLNet()
    elif dnn == 'vgg16':
        net = models.VGG(dnn.upper())
    elif dnn == 'alexnet':
        net = torchvision.models.alexnet()
    elif dnn == 'lstman4':
        net, ext = models.LSTMAN4(datapath=kwargs['datapath'])
    elif dnn == 'lstm':
        net = lstmpy.lstm(vocab_size=kwargs['vocab_size'], batch_size=kwargs['batch_size'])

    else:
        errstr = 'Unsupported neural network %s' % dnn
        logger.error(errstr)
        raise ValueError(errstr)  # raising a bare string is a TypeError in Python 3
    return net, ext
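
The if/elif chain above can also be collapsed into a lookup table, mirroring the dispatch-dict style of Example No. 1. A sketch under the same imports (models, MnistNet, lstmpy), shown for a few of the names only:

def create_net_from_registry(num_classes, dnn='resnet20', **kwargs):
    # Factories are wrapped in lambdas so only the requested model is built.
    builders = {
        'resnet20': lambda: models.__dict__['resnet20'](num_classes=num_classes),
        'mnistnet': lambda: MnistNet(),
        'vgg16':    lambda: models.VGG('VGG16'),
        'lstm':     lambda: lstmpy.lstm(vocab_size=kwargs['vocab_size'],
                                        batch_size=kwargs['batch_size']),
    }
    if dnn not in builders:
        raise ValueError('Unsupported neural network %s' % dnn)
    return builders[dnn](), None  # ext is only non-None for lstman4 in the original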
Example No. 3
def create_net(num_classes, dnn='resnet20', **kwargs):
    ext = None
    if dnn in ['resnet20', 'resnet56', 'resnet110']:
        net = models.__dict__[dnn](num_classes=num_classes)
    elif dnn == 'resnet50':
        net = torchvision.models.resnet50(num_classes=num_classes)
    elif dnn == 'resnet101':
        net = torchvision.models.resnet101(num_classes=num_classes)
    elif dnn == 'resnet152':
        net = torchvision.models.resnet152(num_classes=num_classes)
    elif dnn == 'densenet121':
        net = torchvision.models.densenet121(num_classes=num_classes)
    elif dnn == 'densenet161':
        net = torchvision.models.densenet161(num_classes=num_classes)
    elif dnn == 'densenet201':
        net = torchvision.models.densenet201(num_classes=num_classes)
    elif dnn == 'inceptionv4':
        net = models.inceptionv4(num_classes=num_classes)
    elif dnn == 'inceptionv3':
        net = torchvision.models.inception_v3(num_classes=num_classes)
    elif dnn == 'vgg16i':  # vgg16 for imagenet
        net = torchvision.models.vgg16(num_classes=num_classes)
    elif dnn == 'googlenet':
        net = models.googlenet()
    elif dnn == 'mnistnet':
        net = MnistNet()
    elif dnn == 'fcn5net':
        net = models.FCN5Net()
    elif dnn == 'lenet':
        net = models.LeNet()
    elif dnn == 'lr':
        net = models.LinearRegression()
    elif dnn == 'vgg16':
        net = models.VGG(dnn.upper())
    elif dnn == 'alexnet':
        #net = models.AlexNet()
        net = torchvision.models.alexnet()
    elif dnn == 'lstman4':
        net, ext = models.LSTMAN4(datapath=kwargs['datapath'])
    elif dnn == 'lstm':
        # model = lstm(embedding_dim=args.hidden_size, num_steps=args.num_steps, batch_size=args.batch_size,
        #              vocab_size=vocab_size, num_layers=args.num_layers, dp_keep_prob=args.dp_keep_prob)
        net = lstmpy.lstm(vocab_size=kwargs['vocab_size'],
                          batch_size=kwargs['batch_size'])

    else:
        errstr = 'Unsupported neural network %s' % dnn
        logger.error(errstr)
        raise ValueError(errstr)  # raising a bare string is a TypeError in Python 3
    return net, ext
Example No. 4
def config_net(net_name="VGG"):
    assert net_name in __all_models__, "Unimplemented architecture"
    if net_name == "VGG":
        return models.VGG("VGG19")
    elif net_name == "ResNet":
        return models.ResNet18()
    elif net_name == "ResNeXt":
        return models.ResNeXt29_2x64d()
    elif net_name == "MobileNet":
        return models.MobileNetV2()
    elif net_name == "DenseNet":
        return models.DenseNet121()
    elif net_name == "DPN":
        return models.DPN92()
    elif net_name == "EfficientNet":
        return models.EfficientNetB0()
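
Usage is a single call, assuming __all_models__ enumerates the seven names handled above (the definition below is an assumption, not from the source):

__all_models__ = ["VGG", "ResNet", "ResNeXt", "MobileNet",
                  "DenseNet", "DPN", "EfficientNet"]

net = config_net("ResNet")  # -> models.ResNet18()
# config_net("LeNet")       # would trip the assert: "Unimplemented architecture"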
Example No. 5
def create_net(num_classes, dnn='resnet20', **kwargs):
    ext = None
    if dnn in ['resnet20', 'resnet56', 'resnet110']:
        net = models.__dict__[dnn](num_classes=num_classes)
    elif dnn == 'resnet50':
        #net = models.__dict__['resnet50'](num_classes=num_classes)
        net = torchvision.models.resnet50(num_classes=num_classes)
    elif dnn == 'inceptionv4':
        net = models.inceptionv4(num_classes=num_classes)
    elif dnn == 'inceptionv3':
        net = torchvision.models.inception_v3(num_classes=num_classes)
    elif dnn == 'vgg16i':  # vgg16 for imagenet
        net = torchvision.models.vgg16(num_classes=num_classes)
    elif dnn == 'googlenet':
        net = models.googlenet()
    elif dnn == 'mnistnet':
        net = MnistNet()
    elif dnn == 'fcn5net':
        net = models.FCN5Net()
    elif dnn == 'lenet':
        net = models.LeNet()
    elif dnn == 'lr':
        net = models.LinearRegression()
    elif dnn == 'vgg16':
        net = models.VGG(dnn.upper())
    elif dnn == 'alexnet':
        net = torchvision.models.alexnet()
    elif dnn == 'lstman4':
        net, ext = models.LSTMAN4(datapath=kwargs['datapath'])
    elif dnn == 'lstm':
        net = lstmpy.lstm(vocab_size=kwargs['vocab_size'],
                          batch_size=kwargs['batch_size'])

    else:
        errstr = 'Unsupported neural network %s' % dnn
        logger.error(errstr)
        raise ValueError(errstr)  # raising a bare string is a TypeError in Python 3
    return net, ext
Example No. 6
def setup_and_run(args, criterion, device, train_loader, test_loader,
                  val_loader, logging, results):
    global BEST_ACC
    print("\n#### Running REF ####")

    # architecture
    if args.architecture == "MLP":
        model = models.MLP(args.input_dim, args.hidden_dim,
                           args.output_dim).to(device)
    elif args.architecture == "LENET300":
        model = models.LeNet300(args.input_dim, args.output_dim).to(device)
    elif args.architecture == "LENET5":
        model = models.LeNet5(args.input_channels, args.im_size,
                              args.output_dim).to(device)
    elif "VGG" in args.architecture:
        assert (args.architecture == "VGG11" or args.architecture == "VGG13"
                or args.architecture == "VGG16"
                or args.architecture == "VGG19")
        model = models.VGG(args.architecture, args.input_channels,
                           args.im_size, args.output_dim).to(device)
    elif args.architecture == "RESNET18":
        model = models.ResNet18(args.input_channels, args.im_size,
                                args.output_dim).to(device)
    elif args.architecture == "RESNET34":
        model = models.ResNet34(args.input_channels, args.im_size,
                                args.output_dim).to(device)
    elif args.architecture == "RESNET50":
        model = models.ResNet50(args.input_channels, args.im_size,
                                args.output_dim).to(device)
    elif args.architecture == "RESNET101":
        model = models.ResNet101(args.input_channels, args.im_size,
                                 args.output_dim).to(device)
    elif args.architecture == "RESNET152":
        model = models.ResNet152(args.input_channels, args.im_size,
                                 args.output_dim).to(device)
    else:
        print('Architecture type "{0}" not recognized, exiting ...'.format(
            args.architecture))
        exit()

    # optimizer
    if args.optimizer == "ADAM":
        optimizer = optim.Adam(model.parameters(),
                               lr=args.learning_rate,
                               weight_decay=args.weight_decay)
    elif args.optimizer == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.learning_rate,
            momentum=args.momentum,
            nesterov=args.nesterov,
            weight_decay=args.weight_decay,
        )
    else:
        print('Optimizer type "{0}" not recognized, exiting ...'.format(
            args.optimizer))
        exit()

    # lr-scheduler
    if args.lr_decay == "STEP":
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=1,
                                              gamma=args.lr_scale)
    elif args.lr_decay == "EXP":
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                     gamma=args.lr_scale)
    elif args.lr_decay == "MSTEP":
        x = args.lr_interval.split(",")
        lri = [int(v) for v in x]
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=lri,
                                                   gamma=args.lr_scale)
        args.lr_interval = 1  # lr_interval handled in scheduler!
    else:
        print('LR decay type "{0}" not recognized, exiting ...'.format(
            args.lr_decay))
        exit()

    init_weights(model, xavier=True)
    logging.info(model)
    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info("Number of parameters: %d", num_parameters)

    start_epoch = -1
    iters = 0  # total no of iterations, used to do many things!
    # optionally resume from a checkpoint
    if args.eval:
        logging.info('Loading checkpoint file "{0}" for evaluation'.format(
            args.eval))
        if not os.path.isfile(args.eval):
            print(
                'Checkpoint file "{0}" for evaluation not recognized, exiting ...'
                .format(args.eval))
            exit()
        checkpoint = torch.load(args.eval)
        model.load_state_dict(checkpoint["state_dict"])

    elif args.resume:
        checkpoint_file = args.resume
        logging.info('Loading checkpoint file "{0}" to resume'.format(
            args.resume))
        if not os.path.isfile(checkpoint_file):
            print('Checkpoint file "{0}" not recognized, exiting ...'.format(
                checkpoint_file))
            exit()
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint["epoch"]
        assert args.architecture == checkpoint["architecture"]
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        BEST_ACC = checkpoint["best_acc1"]
        iters = checkpoint["iters"]
        logging.debug("best_acc1: {0}, iters: {1}".format(BEST_ACC, iters))

    if not args.eval:
        logging.info("Training...")
        model.train()
        st = timer()

        for e in range(start_epoch + 1, args.num_epochs):
            for i, (data, target) in enumerate(train_loader):
                l = train_step(model, device, data, target, optimizer,
                               criterion)
                if i % args.log_interval == 0:
                    acc1, acc5 = evaluate(args,
                                          model,
                                          device,
                                          val_loader,
                                          training=True)
                    logging.info(
                        "Epoch: {0},\t Iter: {1},\t Loss: {loss:.5f},\t Val-Acc1: {acc1:.2f} "
                        "(Best: {best:.2f}),\t Val-Acc5: {acc5:.2f}".format(
                            e, i, loss=l, acc1=acc1, best=BEST_ACC, acc5=acc5))

                if iters % args.lr_interval == 0:
                    lr = args.learning_rate
                    for param_group in optimizer.param_groups:
                        lr = param_group["lr"]
                    scheduler.step()
                    for param_group in optimizer.param_groups:
                        if lr != param_group["lr"]:
                            logging.info("lr: {0}".format(
                                param_group["lr"]))  # print if changed
                iters += 1

            # save checkpoint
            acc1, acc5 = evaluate(args,
                                  model,
                                  device,
                                  val_loader,
                                  training=True)
            results.add(
                epoch=e,
                iteration=i,
                train_loss=l,
                val_acc1=acc1,
                best_val_acc1=BEST_ACC,
            )
            util.save_checkpoint(
                {
                    "epoch": e,
                    "architecture": args.architecture,
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "best_acc1": BEST_ACC,
                    "iters": iters,
                },
                is_best=False,
                path=args.save_dir,
            )
            results.save()

        et = timer()
        logging.info("Elapsed time: {0} seconds".format(et - st))

        acc1, acc5 = evaluate(args, model, device, val_loader, training=True)
        logging.info(
            "End of training, Val-Acc: {acc1:.2f} (Best: {best:.2f}), Val-Acc5: {acc5:.2f}"
            .format(acc1=acc1, best=BEST_ACC, acc5=acc5))
        # load saved model
        saved_model = torch.load(args.save_name)
        model.load_state_dict(saved_model["state_dict"])
    # end of training

    # eval-set
    if args.eval_set != "TRAIN" and args.eval_set != "TEST":
        print('Evaluation set "{0}" not recognized ...'.format(args.eval_set))

    logging.info("Evaluating REF on the {0} set...".format(args.eval_set))
    st = timer()
    if args.eval_set == "TRAIN":
        acc1, acc5 = evaluate(args, model, device, train_loader)
    else:
        acc1, acc5 = evaluate(args, model, device, test_loader)
    et = timer()
    logging.info("Accuracy: top-1: {acc1:.2f}, top-5: {acc5:.2f}%".format(
        acc1=acc1, acc5=acc5))
    logging.info("Elapsed time: {0} seconds".format(et - st))
Example No. 7
p = argparse.ArgumentParser()
p.add_argument('--input', '-i', default='images/bird.png')
p.add_argument('--gpu', '-g', type=int, default=-1)
p.add_argument('--arch', '-a', choices=['alex', 'vgg'], default='alex')
p.add_argument('--mask', '-m', action='store_true')
args = p.parse_args()

if __name__ == '__main__':
    if args.arch == 'alex':
        model = models.Alex()
        layers = [
            'conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'fc6', 'fc7', 'fc8'
        ]
    elif args.arch == 'vgg':
        model = models.VGG()
        layers = [
            'conv1_2', 'conv2_2', 'conv3_3', 'conv4_3', 'conv5_3', 'fc6',
            'fc7', 'fc8'
        ]

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    src = cv2.imread(args.input, 1)
    src = cv2.resize(src, (model.size, model.size))
    src = src.astype(np.float32) - np.float32([103.939, 116.779, 123.68])
    src = src.transpose(2, 0, 1)[np.newaxis, :, :, :]
    src = model.xp.array(src)
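
The last four lines implement the standard Caffe-style input pipeline: BGR channel order, ImageNet mean subtraction, and NCHW layout. A self-contained helper with the mean values taken from the snippet:

import cv2
import numpy as np

MEAN_BGR = np.float32([103.939, 116.779, 123.68])  # ImageNet means, BGR order

def preprocess(path, size):
    """Load an image the way the script above feeds it to the model."""
    img = cv2.imread(path, 1)                  # cv2 loads HWC in BGR order
    img = cv2.resize(img, (size, size))
    img = img.astype(np.float32) - MEAN_BGR    # mean subtraction
    return img.transpose(2, 0, 1)[np.newaxis]  # HWC -> CHW, add batch dim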
Example No. 8
            if args.cuda:
                model_train.cuda()
                model_test.cuda()

            if args.pretrained:
                if args.evaluate:
                    model_test.load_state_dict(torch.load(args.pretrained))
                else:
                    model_train.load_state_dict(torch.load(args.pretrained))
                    binop_train = binop_train(model_train)
            else:
                binop_train = binop_train(model_train)

        else:
            if 'VGG' in name:
                model_ori = models.VGG(name)
            elif 'NIN' in name:
                model_ori = models.NIN()
            elif "RESNET18" in name:
                pass
            if args.cuda:
                model_ori.cuda()

            if args.pretrained:
                model_ori.load_state_dict(torch.load(args.pretrained))

    else:
        print('ERROR: specified arch is not supported')
        exit()

    param_dict = dict(
Example No. 9
                for i in range(len(cfg)):
                    cfg[i] = int(cfg[i] * (1 - args.pruning_ratio))
                    temp_cfg[i] = cfg[i] * args.depth_wide[1]

        elif args.target == 'ip':
            if args.arch == 'LeNet_300_100':
                cfg = [300, 100]
                for i in range(len(cfg)):
                    cfg[i] = round(cfg[i] * (1 - args.pruning_ratio))
                temp_cfg = cfg
            pass

    # generate the model
    if args.arch == 'VGG':
        model = models.VGG(num_classes, cfg=cfg)
    elif args.arch == 'LeNet_300_100':
        model = models.LeNet_300_100(bias_flag=True, cfg=cfg)
    elif args.arch == 'ResNet':
        model = models.ResNet(int(args.depth_wide), num_classes, cfg=cfg)
    elif args.arch == 'WideResNet':
        model = models.WideResNet(args.depth_wide[0],
                                  num_classes,
                                  widen_factor=args.depth_wide[1],
                                  cfg=cfg)
    else:
        pass

    if args.cuda:
        model.cuda()
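
A worked example of the pruning arithmetic above, assuming pruning_ratio=0.5 on LeNet_300_100:

pruning_ratio = 0.5
cfg = [300, 100]
cfg = [round(c * (1 - pruning_ratio)) for c in cfg]
print(cfg)  # [150, 50] -> models.LeNet_300_100(bias_flag=True, cfg=cfg)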
Example No. 10
# Model
if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load(model_path)
    net = checkpoint['net']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
else:
    print('==> Building model..')

    model_name = args.model  # type: str
    if model_name.startswith('VGG') or model_name.startswith('vgg'):
        net = models.VGG(model_name.upper())
    else:
        net = getattr(models, model_name)()
    # net = VGG('VGG19')
    # net = ResNet18()
    # net = PreActResNet18()
    # net = GoogLeNet()
    # net = DenseNet121()
    # net = ResNeXt29_2x64d()
    # net = MobileNet()
    # net = MobileNetV2()
    # net = DPN92()
    # net = ShuffleNetG2()
    # net = SENet18()

criterion = nn.CrossEntropyLoss()
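
The getattr fallback above lets any class exported by models be selected by name, while VGG variants are routed through models.VGG. A brief illustration (names taken from the commented list):

net = getattr(models, 'ResNet18')()  # equivalent to models.ResNet18()
net = models.VGG('VGG19')            # what 'vgg19'/'VGG19' resolves to above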
Example No. 11
def setup_and_run(args, criterion, device, train_loader, test_loader,
                  val_loader, logging, results):
    global BEST_ACC
    print('\n#### Running continuous-net ####')

    # architecture
    if 'VGG' in args.architecture:
        assert (args.architecture == 'VGG11' or args.architecture == 'VGG13'
                or args.architecture == 'VGG16'
                or args.architecture == 'VGG19')
        model = models.VGG(args.architecture, args.input_channels,
                           args.im_size, args.output_dim).to(device)
    elif args.architecture == 'RESNET18':
        model = models.ResNet18(args.input_channels, args.im_size,
                                args.output_dim).to(device)
    else:
        print('Architecture type "{0}" not recognized, exiting ...'.format(
            args.architecture))
        exit()

    # optimizer
    if args.optimizer == 'ADAM':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.learning_rate,
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              nesterov=args.nesterov,
                              weight_decay=args.weight_decay)
    else:
        print('Optimizer type "{0}" not recognized, exiting ...'.format(
            args.optimizer))
        exit()

    # lr-scheduler
    if args.lr_decay == 'STEP':
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=1,
                                              gamma=args.lr_scale)
    elif args.lr_decay == 'EXP':
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                     gamma=args.lr_scale)
    elif args.lr_decay == 'MSTEP':
        x = args.lr_interval.split(',')
        lri = [int(v) for v in x]
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=lri,
                                                   gamma=args.lr_scale)
        args.lr_interval = 1  # lr_interval handled in scheduler!
    else:
        print('LR decay type "{0}" not recognized, exiting ...'.format(
            args.lr_decay))
        exit()

    init_weights(model, xavier=True)
    logging.info(model)
    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info("Number of parameters: %d", num_parameters)

    start_epoch = -1
    iters = 0  # total no of iterations, used to do many things!
    # optionally resume from a checkpoint
    if args.eval:
        logging.info('Loading checkpoint file "{0}" for evaluation'.format(
            args.eval))
        if not os.path.isfile(args.eval):
            print 'Checkpoint file "{0}" for evaluation not recognized, exiting ...'.format(
                args.eval)
            exit()
        checkpoint = torch.load(args.eval)
        model.load_state_dict(checkpoint['state_dict'])

    elif args.resume:
        checkpoint_file = args.resume
        logging.info('Loading checkpoint file "{0}" to resume'.format(
            args.resume))
        if not os.path.isfile(checkpoint_file):
            print('Checkpoint file "{0}" not recognized, exiting ...'.format(
                checkpoint_file))
            exit()
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        assert (args.architecture == checkpoint['architecture'])
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        BEST_ACC = checkpoint['best_acc1']
        iters = checkpoint['iters']
        logging.debug('best_acc1: {0}, iters: {1}'.format(BEST_ACC, iters))

    if not args.eval:
        logging.info('Training...')
        model.train()
        st = timer()

        for e in range(start_epoch + 1, args.num_epochs):
            for i, (data, target) in enumerate(train_loader):
                l = train_step(model, device, data, target, optimizer,
                               criterion)
                if i % args.log_interval == 0:
                    acc1, acc5 = evaluate(args,
                                          model,
                                          device,
                                          val_loader,
                                          training=True)
                    logging.info(
                        'Epoch: {0},\t Iter: {1},\t Loss: {loss:.5f},\t Val-Acc1: {acc1:.2f} '
                        '(Best: {best:.2f}),\t Val-Acc5: {acc5:.2f}'.format(
                            e, i, loss=l, acc1=acc1, best=BEST_ACC, acc5=acc5))

                if iters % args.lr_interval == 0:
                    lr = args.learning_rate
                    for param_group in optimizer.param_groups:
                        lr = param_group['lr']
                    scheduler.step()
                    for param_group in optimizer.param_groups:
                        if lr != param_group['lr']:
                            logging.info('lr: {0}'.format(
                                param_group['lr']))  # print if changed
                iters += 1

            # save checkpoint
            acc1, acc5 = evaluate(args,
                                  model,
                                  device,
                                  val_loader,
                                  training=True)
            results.add(epoch=e,
                        iteration=i,
                        train_loss=l,
                        val_acc1=acc1,
                        best_val_acc1=BEST_ACC)
            util.save_checkpoint(
                {
                    'epoch': e,
                    'architecture': args.architecture,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'best_acc1': BEST_ACC,
                    'iters': iters
                },
                is_best=False,
                path=args.save_dir)
            results.save()

        et = timer()
        logging.info('Elapsed time: {0} seconds'.format(et - st))

        acc1, acc5 = evaluate(args, model, device, val_loader, training=True)
        logging.info(
            'End of training, Val-Acc: {acc1:.2f} (Best: {best:.2f}), Val-Acc5: {acc5:.2f}'
            .format(acc1=acc1, best=BEST_ACC, acc5=acc5))
        # load saved model
        saved_model = torch.load(args.save_name)
        model.load_state_dict(saved_model['state_dict'])
    # end of training

    # eval-set
    if args.eval_set != 'TRAIN' and args.eval_set != 'TEST':
        print('Evaluation set "{0}" not recognized ...'.format(args.eval_set))

    logging.info('Evaluating continuous-net on the {0} set...'.format(
        args.eval_set))
    st = timer()
    if args.eval_set == 'TRAIN':
        acc1, acc5 = evaluate(args, model, device, train_loader)
    else:
        acc1, acc5 = evaluate(args, model, device, test_loader)
    et = timer()
    logging.info('Accuracy: top-1: {acc1:.2f}, top-5: {acc5:.2f}%'.format(
        acc1=acc1, acc5=acc5))
    logging.info('Elapsed time: {0} seconds'.format(et - st))
Example No. 12
    def build_and_train(self, **kwargs):
        config_defaults = {
            'use_mse': False,
            'use_bce': False,
            'lr': 0.01,
            'out_path': None,
            'train_batch_size': 128,
            'regularization': None,
            'regularization_start_epoch': 10,
            'l1': 0,
            'l2': 0,
            'bias_l1': 0,
            'use_scrambling': False,
            'use_overlay': False,
            'use_elliptical': False,
            'use_quadratic': False,
            'quadratic_grad_scale': 1,
            'log_strength_inc': 0.001,
            'log_strength_start': 0.001,
            'log_strength_stop': 1,
            'batch_size_multiplier': 1,
            'report_interval': 150,
            'curvature_multiplier_inc': 1e-4,
            'curvature_multiplier_start': 0,
            'curvature_multiplier_stop': 1,
            'n_epochs': 40,
            'mse_weighted': False,
        }
        unrecognized_params = [
            k for k in kwargs
            if not (k in models.VGG.config_defaults or k in config_defaults)
        ]
        assert not unrecognized_params, 'Unrecognized parameter: ' + str(
            unrecognized_params)

        conf = {**config_defaults, **kwargs}
        print('Using training service config:', conf)
        models.log_strength_inc = float(conf['log_strength_inc'])
        models.log_strength_start = float(conf['log_strength_start'])
        models.log_strength_stop = float(conf['log_strength_stop'])
        models.curvature_multiplier_inc = float(
            conf['curvature_multiplier_inc'])
        models.curvature_multiplier_start = float(
            conf['curvature_multiplier_start'])
        models.curvature_multiplier_stop = float(
            conf['curvature_multiplier_stop'])

        model_kwargs = {
            k: v
            for k, v in kwargs.items() if k in models.VGG.config_defaults
        }
        net = models.VGG(**model_kwargs)
        print(net)

        self.trainloader = torch.utils.data.DataLoader(
            self.trainset,
            shuffle=True,
            num_workers=2,
            batch_size=conf['train_batch_size'] *
            conf['batch_size_multiplier'])
        self.testloader = torch.utils.data.DataLoader(self.testset,
                                                      batch_size=256,
                                                      shuffle=False,
                                                      num_workers=2)
        # tried ADAM already: it works for ReLU but fails to train ReLog (it doesn't just overfit,
        # it increases the loss after a few epochs)
        net = net.to(self.device)
        for epoch in range(conf['n_epochs']):
            self.optimizer = optim.SGD(net.parameters(),
                                       lr=conf['lr'],
                                       momentum=0.9)
            self.train(net, epoch, conf)
            new_train_acc = self.train_acc_estimate(net, epoch, conf)
            training_has_collapsed = (
                self.last_train_acc
                and new_train_acc < 0.5 * self.last_train_acc)
            if training_has_collapsed:
                if conf['regularization'] and epoch >= conf[
                        'regularization_start_epoch']:
                    print(
                        "Training might have collapsed because of excessive regularization, "
                        "please adjust hyperparams, aborting...")
                    if conf['out_path'] and os.path.exists(conf['out_path']):
                        net = torch.load(conf['out_path'])  # recover
                    break
                else:  # attempt recovery
                    if conf['out_path'] and os.path.exists(conf['out_path']):
                        models.freeze_hyperparams = True
                        net = torch.load(conf['out_path'])  # recover
                        for _ in range(5):
                            print(
                                '\n=== Trying to overcome collapse point ===')
                            self.train(net, epoch, conf)
                        print('\nNow, retrying normal training')
                        models.freeze_hyperparams = False
                    else:
                        print(
                            "Collapse of training detected but no model available on disk"
                        )
            else:  # only test and write model in normal state
                self.last_train_acc = new_train_acc
                self.last_test_acc = self.test(net, epoch, conf)
                if conf['out_path']:
                    torch.save(net, conf['out_path'])
                    print('Model saved to %s' % conf['out_path'])
        return net
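
The collapse check above compares successive training-accuracy estimates. As a standalone predicate (a sketch; the 0.5 threshold comes from the snippet):

def training_collapsed(last_acc, new_acc, threshold=0.5):
    """True when accuracy drops below `threshold` of its previous value."""
    return bool(last_acc) and new_acc < threshold * last_acc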
Example No. 13
def setup_and_run(args, criterion, device, train_loader, test_loader, val_loader, logging, results, summary_writer):
    global BEST_ACC
    print('\n#### Running binarized-net ####')

    # quantized levels
    if (not args.tanh and args.quant_levels != 2) or args.quant_levels > 3:
        print('Quantization levels "{0}" is invalid, exiting ...'.format(args.quant_levels))
        exit()
    # for tanh, Q_l = {-1, 0, 1}; rounding: -1 on (-inf, -0.5], 0 on (-0.5, 0.5), 1 on [0.5, inf)

    if args.zeroone and args.tanh:
        print('zeroone cannot be true while tanh is, setting zeroone False ...')
        args.zeroone = False

    # architecture
    if 'VGG' in args.architecture:
        assert(args.architecture == 'VGG11' or args.architecture == 'VGG13' or args.architecture == 'VGG16' 
                or args.architecture == 'VGG19')
        model = models.VGG(args.architecture, args.input_channels, args.im_size, args.output_dim).to(device)
    elif args.architecture == 'RESNET18':
        model = models.ResNet18(args.input_channels, args.im_size, args.output_dim).to(device)
    else:
        print('Architecture type "{0}" not recognized, exiting ...'.format(args.architecture))
        exit()

    # optimizer
    if args.optimizer == 'ADAM':
        optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    elif args.optimizer == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, 
                momentum=args.momentum, nesterov=args.nesterov, weight_decay=args.weight_decay)
    else:
        print('Optimizer type "{0}" not recognized, exiting ...'.format(args.optimizer))
        exit()
    
    # lr-scheduler
    if args.lr_decay == 'STEP':
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=args.lr_scale)
    elif args.lr_decay == 'EXP':
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_scale)
    elif args.lr_decay == 'MSTEP':
        x = args.lr_interval.split(',')
        lri = [int(v) for v in x]
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=lri, gamma=args.lr_scale)
        args.lr_interval = 1    # lr_interval handled in scheduler!
    else:
        print('LR decay type "{0}" not recognized, exiting ...'.format(args.lr_decay))
        exit()

    init_weights(model, device, xavier=True)
    if not args.eval:
        logging.info(model)
    num_parameters = sum([l.nelement() for l in model.parameters()])
    if not args.eval:
        logging.info("Number of parameters: %d", num_parameters)

    start_epoch = -1
    beta = 1    # discrete forcing scalar, used only for softmax based projection
    iters = 0   # total no of iterations, used to do many things!
    amodel = auxmodel(model)
    # optionally resume from a checkpoint
    if args.eval:
        logging.info('Loading checkpoint file "{0}" for evaluation'.format(args.eval))
        if not os.path.isfile(args.eval):
            print 'Checkpoint file "{0}" for evaluation not recognized, exiting ...'.format(args.eval)
            exit()
        checkpoint = torch.load(args.eval)
        model.load_state_dict(checkpoint['state_dict'])
        beta = checkpoint['beta']
        logging.debug('beta: {0}'.format(beta))

    elif args.resume:
        checkpoint_file = args.resume
        logging.info('Loading checkpoint file "{0}" to resume'.format(args.resume))
        if not os.path.isfile(checkpoint_file):
            print('Checkpoint file "{0}" not recognized, exiting ...'.format(checkpoint_file))
            exit()
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        assert(args.architecture == checkpoint['architecture'])
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        BEST_ACC = checkpoint['best_acc1']
        beta = checkpoint['beta']
        iters = checkpoint['iters']
        logging.debug('best_acc1: {0}, beta: {1}, iters: {2}'.format(BEST_ACC, beta, iters))

    batch_per_epoch = len(train_loader)

    if not args.eval:
        logging.info('Training...')
        model.train()
        st = timer()                
        for e in range(start_epoch + 1, args.num_epochs):
            for i, (data, target) in enumerate(train_loader):
                l = train_step(args, amodel, model, device, data, target, optimizer, criterion, beta=beta)
    
                if i % args.log_interval == 0:
                    acc1, acc5 = evaluate(args, amodel, model, device, val_loader, training=True, beta=beta,
                                          summary_writer=summary_writer, iterations=e*batch_per_epoch+i)
                    logging.info('Epoch: {0},\t Iter: {1},\t Loss: {loss:.5f},\t Val-Acc1: {acc1:.2f} '
                                 '(Best: {best:.2f}),\t Val-Acc5: {acc5:.2f}'.format(e, i, 
                                     loss=l, acc1=acc1, best=BEST_ACC, acc5=acc5))
    
                if iters % args.beta_interval == 0:
                    # beta = beta * args.beta_scale
                    beta = min(beta * args.beta_scale, BETAMAX)
                    optimizer.beta_mda = beta
                    logging.info('beta: {0}'.format(beta))

                if iters % args.lr_interval == 0:
                    lr = args.learning_rate
                    for param_group in optimizer.param_groups:
                        lr = param_group['lr']                        
                    scheduler.step()
                    for param_group in optimizer.param_groups:
                        if lr != param_group['lr']:
                            logging.info('lr: {0}'.format(param_group['lr']))   # print if changed
                iters += 1

            # save checkpoint
            acc1, acc5 = evaluate(args, amodel, model, device, val_loader, training=True, beta=beta)
            results.add(epoch=e, iteration=i, train_loss=l, val_acc1=acc1, best_val_acc1=BEST_ACC)
            util.save_checkpoint({'epoch': e, 'architecture': args.architecture, 'state_dict': model.state_dict(), 
                'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(), 
                'best_acc1': BEST_ACC, 'iters': iters, 'beta': beta}, is_best=False, path=args.save_dir)
            results.save()
    
        et = timer()
        logging.info('Elapsed time: {0} seconds'.format(et - st))
    
        acc1, acc5 = evaluate(args, amodel, model, device, val_loader, training=True, beta=beta)
        logging.info('End of training, Val-Acc: {acc1:.2f} (Best: {best:.2f}), Val-Acc5: {acc5:.2f}'.format(acc1=acc1, 
            best=BEST_ACC, acc5=acc5))
        # load saved model
        saved_model = torch.load(args.save_name)
        model.load_state_dict(saved_model['state_dict'])
        beta = saved_model['beta']
    # end of training

    # eval-set
    if args.tanh:
        dotanh(args, model, beta=beta)
    if args.eval_set != 'TRAIN' and args.eval_set != 'TEST':
        print('Evaluation set "{0}" not recognized ...'.format(args.eval_set))

    logging.info('Evaluating fractional binarized-net on the {0} set...'.format(args.eval_set))
    st = timer()                
    if args.eval_set == 'TRAIN':
        acc1, acc5 = evaluate(args, amodel, model, device, train_loader)
    else: 
        acc1, acc5 = evaluate(args, amodel, model, device, test_loader)
    et = timer()
    logging.info('Accuracy: top-1: {acc1:.2f}, top-5: {acc5:.2f}%'.format(acc1=acc1, acc5=acc5))
    logging.info('Elapsed time: {0} seconds'.format(et - st))

    doround(args, model)
    logging.info('Evaluating discrete binarized-net on the {0} set...'.format(args.eval_set))
    st = timer()                
    if args.eval_set == 'TRAIN':
        acc1, acc5 = evaluate(args, amodel, model, device, train_loader)
    else: 
        acc1, acc5 = evaluate(args, amodel, model, device, test_loader)
    et = timer()
    logging.info('Accuracy: top-1: {acc1:.2f}, top-5: {acc5:.2f}%'.format(acc1=acc1, acc5=acc5))
    logging.info('Elapsed time: {0} seconds'.format(et - st))