Exemplo n.º 1
0
    def setup(self, flags):
        """Configure determinism, build the requested network and log the flags.

        Args:
            flags: Run configuration. This method reads ``deterministic``,
                ``seed``, ``dataset``, ``model`` and ``logs``; for the
                ``'wrn'`` model it additionally reads ``layers``,
                ``widen_factor`` and ``droprate``.

        Raises:
            Exception: If ``flags.model`` is not one of ``'densenet'``,
                ``'wrn'``, ``'allconv'`` or ``'resnext'``.
        """
        torch.backends.cudnn.deterministic = flags.deterministic
        print('torch.backends.cudnn.deterministic:', torch.backends.cudnn.deterministic)
        fix_all_seed(flags.seed)

        # Every dataset other than 'cifar10' is treated as having 100 classes.
        num_classes = 10 if flags.dataset == 'cifar10' else 100

        if flags.model == 'densenet':
            self.network = densenet(num_classes=num_classes)
        elif flags.model == 'wrn':
            self.network = WideResNet(flags.layers, num_classes, flags.widen_factor, flags.droprate)
        elif flags.model == 'allconv':
            self.network = AllConvNet(num_classes)
        elif flags.model == 'resnext':
            self.network = resnext29(num_classes=num_classes)
        else:
            raise Exception('Unknown model.')
        self.network = self.network.cuda()

        print(self.network)
        print('flags:', flags)

        # exist_ok=True removes the check-then-create race of the previous
        # `if not os.path.exists(...)` guard (e.g. several runs started at once
        # with the same log directory).
        os.makedirs(flags.logs, exist_ok=True)

        flags_log = os.path.join(flags.logs, 'flags_log.txt')
        write_log(flags, flags_log)
Exemplo n.º 2
0
    def _load_model(self):
        """Build ``self.model`` from ``self.args`` and optionally restore a checkpoint.

        Reads ``self.args.model``, ``self.args.model_args`` and
        ``self.args.resume``. Sets ``self.model`` and ``self.policies``; when
        the checkpoint file exists, its weights are loaded into the model and
        ``self.args.start_epoch`` is set from the checkpoint's ``'epoch'``.

        Raises:
            ValueError: If ``self.args.model`` is neither ``'resnet'`` nor
                ``'densenet'``.
        """
        print('loading model...')

        if self.args.model == 'resnet':
            from models.resnet import resnet
            self.model = resnet(**self.args.model_args)
        elif self.args.model == 'densenet':
            from models.densenet import densenet
            self.model = densenet(**self.args.model_args)
        else:
            # Fail here rather than with an unrelated AttributeError on
            # `self.model` a few lines below.
            raise ValueError('Unknown model: {}'.format(self.args.model))

        self.policies = self.model.parameters()

        if self.args.resume:
            if os.path.isfile(self.args.resume):
                print(("=> loading checkpoint '{}'".format(self.args.resume)))
                checkpoint = torch.load(self.args.resume)

                # The model is not wrapped in DataParallel here, so the
                # 7-character 'module.' prefix that such a wrapper adds to
                # parameter names must be removed. Strip it only when present
                # so checkpoints saved from an unwrapped model load as well.
                # NOTE(review): assumes the prefix being removed is 'module.'
                # -- confirm against the checkpoints actually in use.
                prefix = 'module.'
                state_dict = collections.OrderedDict()
                for key, value in checkpoint['state_dict'].items():
                    if key.startswith(prefix):
                        key = key[len(prefix):]
                    state_dict[key] = value

                self.args.start_epoch = checkpoint['epoch']
                self.model.load_state_dict(state_dict)
                # Report the checkpoint path, matching the "loading" message.
                print(("=> loaded checkpoint '{}' (epoch {})".format(
                    self.args.resume, checkpoint['epoch'])))
            else:
                print(("=> no checkpoint found at '{}'".format(
                    self.args.resume)))

        print('model load finished!')
Exemplo n.º 3
0
def get_model(model, args):
    """Instantiate the network identified by ``model``.

    Args:
        model: Architecture name; one of ``'alexnet'``, ``'resnet'``,
            ``'wideresnet'`` or ``'densenet'``.
        args: Configuration object. ``'resnet'`` reads ``args.dataset``;
            ``'wideresnet'`` reads ``args.layers``, ``args.dataset``,
            ``args.widen_factor``, ``args.droprate`` and ``args.gbn``.
            It is not touched for the other architectures.

    Returns:
        The constructed network, or ``None`` when ``model`` is not one of
        the recognised names.
    """
    if model == 'alexnet':
        return alexnet()
    if model == 'resnet':
        return resnet(dataset=args.dataset)
    if model == 'wideresnet':
        # Conditional expression instead of the fragile `cond and a or b`.
        num_classes = 10 if args.dataset == 'cifar10' else 100
        return WideResNet(args.layers, num_classes, args.widen_factor,
                          dropRate=args.droprate, gbn=args.gbn)
    if model == 'densenet':
        return densenet()
    # Unrecognised names yield None, as before; the return is now explicit.
    return None
Exemplo n.º 4
0
        data_path + 'data.cifar100',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              **kwargs)

# Build the teacher network (a DenseNet configured from the command-line
# arguments) and restore its weights from the teacher directory.
logger.info("Loading teacher model from {}".format(args.teacher))
model_teacher = densenet.densenet(
    num_classes=num_classes,
    depth=args.depth_dense,
    block=densenet.Bottleneck,
    growthRate=args.growthRate,
    compressionRate=args.compressionRate,
    dropRate=args.drop,
)
# args.teacher is used as a directory; the weights are read from the
# 'state_dict' entry of its "model_best.pth.tar" file.
checkpoint_path = args.teacher + "/model_best.pth.tar"
model_teacher.load_state_dict(torch.load(checkpoint_path)['state_dict'])
logger.info(model_teacher)

# Build the student network (a ResNet); no weights are loaded for it here.
logger.info("Initializing student model...")
model = resnet.resnet(depth=args.depth,
                      num_classes=num_classes,
                      num_blocks=num_blocks)
logger.info(model)

# Move the teacher to the GPU when CUDA use was requested.
if args.cuda:
    model_teacher.cuda()
Exemplo n.º 5
0
def _report_validation(model, x_test, y_test):
    """Evaluate ``model`` on the test split and print its loss and accuracy."""
    score = model.evaluate(x_test, y_test, verbose=0)
    print('val loss: {}'.format(score[0]))
    print('val acc: {}'.format(score[1]))


def _prune_layers(model, compress_rate):
    """Prune each weight layer except the first; return the masks by layer index."""
    masks = {}
    # The first layer that carries weights is deliberately left uncompressed.
    first_conv = True
    for layer_id, layer in enumerate(model.layers):
        weight = layer.get_weights()
        if len(weight) < 2:
            continue
        if first_conv:
            first_conv = False
            continue
        w = deepcopy(weight)
        pruned, mask = prune_weights(w[0], compress_rate=compress_rate)
        masks[layer_id] = mask
        w[0] = pruned
        layer.set_weights(w)
    return masks


def _apply_masks(model, masks):
    """Multiply the first weight array of every masked layer by its mask."""
    for layer_id, mask in masks.items():
        layer = model.layers[layer_id]
        w = layer.get_weights()
        w[0] = w[0] * mask
        layer.set_weights(w)


def training():
    """Train, prune, fine-tune and export the model selected by ``args``.

    Uses the module-level ``args`` (``data``, ``model``, ``compress_rate`` and,
    depending on the model, ``depth``, ``wide_factor``, ``growth_rate``).

    Steps:
        1. Load the data and pre-train the model with augmented batches.
        2. Save the training history and weights under ``./results/``.
        3. Prune every weight layer except the first and keep the masks.
        4. Fine-tune batch by batch, re-applying the masks after each update.
        5. Save the compressed weights.

    Raises:
        ValueError: If ``args.model`` is not ``'resnet'``, ``'densenet'``,
            ``'inception'`` or ``'vgg'``.
    """
    batch_size = 64
    epochs = 200
    fine_tune_epochs = 30
    lr = 0.1

    x_train, y_train, x_test, y_test, nb_classes = load_data(args.data)
    print('x_train shape:', x_train.shape)
    print(x_train.shape[0], 'train samples')
    print(x_test.shape[0], 'test samples')
    steps_per_epoch = x_train.shape[0] // batch_size

    y_train = keras.utils.to_categorical(y_train, nb_classes)
    y_test = keras.utils.to_categorical(y_test, nb_classes)

    datagen = ImageDataGenerator(horizontal_flip=True,
                                 width_shift_range=5. / 32,
                                 height_shift_range=5. / 32)
    data_iter = datagen.flow(x_train,
                             y_train,
                             batch_size=batch_size,
                             shuffle=True)

    if args.model == 'resnet':
        model = resnet.resnet(nb_classes,
                              depth=args.depth,
                              wide_factor=args.wide_factor)
        save_name = 'resnet_{}_{}_{}'.format(args.depth, args.wide_factor,
                                             args.data)
    elif args.model == 'densenet':
        model = densenet.densenet(nb_classes,
                                  args.growth_rate,
                                  depth=args.depth)
        save_name = 'densenet_{}_{}_{}'.format(args.depth, args.growth_rate,
                                               args.data)
    elif args.model == 'inception':
        model = inception.inception_v3(nb_classes)
        save_name = 'inception_{}'.format(args.data)
    elif args.model == 'vgg':
        model = vggnet.vgg(nb_classes)
        save_name = 'vgg_{}'.format(args.data)
    else:
        raise ValueError('Does not support {}'.format(args.model))

    model.summary()
    learning_rate_scheduler = LearningRateScheduler(schedule=schedule)
    if args.model == 'vgg':
        # VGG is trained with Adam and without the learning-rate schedule.
        callbacks = None
        model.compile(loss='categorical_crossentropy',
                      optimizer='adam',
                      metrics=['accuracy'])
    else:
        callbacks = [learning_rate_scheduler]
        model.compile(loss='categorical_crossentropy',
                      optimizer=keras.optimizers.SGD(lr=lr,
                                                     momentum=0.9,
                                                     nesterov=True),
                      metrics=['accuracy'])

    # Pre-train.
    history = model.fit_generator(data_iter,
                                  steps_per_epoch=steps_per_epoch,
                                  epochs=epochs,
                                  callbacks=callbacks,
                                  validation_data=(x_test, y_test))
    # exist_ok=True avoids the exists-then-create race of a separate check.
    os.makedirs('./results/', exist_ok=True)
    save_history(history, './results/', save_name)
    model.save_weights('./results/{}_weights.h5'.format(save_name))

    # Prune the weights, remembering one mask per pruned layer.
    masks = _prune_layers(model, args.compress_rate)

    # Evaluate the model right after pruning.
    _report_validation(model, x_test, y_test)

    # Fine-tune: after every batch update the masks are re-applied to the
    # pruned layers' weights.
    for _ in range(fine_tune_epochs):
        for _ in range(steps_per_epoch):
            X, Y = next(data_iter)
            model.train_on_batch(X, Y)
            _apply_masks(model, masks)
        _report_validation(model, x_test, y_test)

    # Save the compressed weights.
    compressed_name = './results/compressed_{}_weights'.format(args.model)
    save_compressed_weights(model, compressed_name)
Exemplo n.º 6
0
    test_data = dset.CIFAR100('~/datasets/cifarpy', train=False, transform=test_transform)
    num_classes = 100


# Batch loaders for both splits: training batches are shuffled, test batches
# are not; both use args.prefetch worker processes and pinned memory.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=args.batch_size, shuffle=True,
    num_workers=args.prefetch, pin_memory=True)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=args.test_bs, shuffle=False,
    num_workers=args.prefetch, pin_memory=True)

# Create the model selected by args.model. The 'densenet' and 'resnext'
# branches also overwrite training hyper-parameters stored on args.
# NOTE(review): there is no else branch, so `net` stays undefined for any
# other model name -- confirm the argument parser restricts the choices.
if args.model == 'densenet':
    args.decay = 0.0001
    args.epochs = 200
    net = densenet(num_classes=num_classes)
elif args.model == 'wrn':
    net = WideResNet(args.layers, num_classes, args.widen_factor, dropRate=args.droprate)
elif args.model == 'allconv':
    net = AllConvNet(num_classes)
elif args.model == 'resnext':
    args.epochs = 200
    net = resnext29(num_classes=num_classes)

# Snapshot of the (possibly overridden) arguments as a plain dict.
state = {k: v for k, v in args._get_kwargs()}
print(state)

start_epoch = 0

# Restore model if desired
if args.load != '':
                       default=1,
                       help='wide factor for WRN')
    args = parse.parse_args()

    # Load the dataset named by args.data; only the test labels are passed
    # through keras.utils.to_categorical here.
    x_train, y_train, x_test, y_test, nb_classes = load_data(args.data)
    y_test = keras.utils.to_categorical(y_test, nb_classes)

    # Build the architecture selected by args.model together with the base
    # name (save_name) that encodes its configuration and the dataset.
    if args.model == 'resnet':
        model = resnet.resnet(nb_classes,
                              depth=args.depth,
                              wide_factor=args.wide_factor)
        save_name = 'resnet_{}_{}_{}'.format(args.depth, args.wide_factor,
                                             args.data)
    elif args.model == 'densenet':
        model = densenet.densenet(nb_classes,
                                  args.growth_rate,
                                  depth=args.depth)
        save_name = 'densenet_{}_{}_{}'.format(args.depth, args.growth_rate,
                                               args.data)
    elif args.model == 'inception':
        model = inception.inception_v3(nb_classes)
        save_name = 'inception_{}'.format(args.data)

    elif args.model == 'vgg':
        model = vggnet.vgg(nb_classes)
        save_name = 'vgg_{}'.format(args.data)
    else:
        # Any other model name is rejected explicitly.
        raise ValueError('Does not support {}'.format(args.model))

    model.summary()