Example #1
	def run(self, pure=True, restore=False, test=False, valid=False, valid_size=-1):
		if not restore:
			_clear_files = ['logs', 'checkpoint', 'snapshot', 'output']
			for file in _clear_files:
				subprocess.run(['rm', '-rf', os.path.join(self.expdir, file)])
		init = self.load_init()
		dataset = 'C10+' if init is None else init.get('dataset', 'C10+')
		run_config = self.load_run_config(print_info=(not pure), dataset=dataset)
		run_config.renew_logs = False
		if valid_size > 0:
			run_config.validation_size = valid_size
		
		data_provider = get_data_provider_by_name(run_config.dataset, run_config.get_config())
		net_config, model_name = self.load_net_config(init, print_info=(not pure))
		model = get_model_by_name(model_name)(self.expdir, data_provider, run_config, net_config, pure=pure)
		start_epoch = 1
		if restore:
			model.load_model()
			epoch_info_file = '%s/checkpoint/epoch.info' % self.expdir
			if os.path.isfile(epoch_info_file):
				start_epoch = json.load(open(epoch_info_file, 'r'))['epoch']
				if not pure:
					print('start epoch: %d' % start_epoch)
		if test:
			print('Testing...')
			loss, accuracy = model.test(data_provider.test, batch_size=200)
			print('mean cross_entropy: %f, mean accuracy: %f' % (loss, accuracy))
			json.dump({'test_loss': '%s' % loss, 'test_acc': '%s' % accuracy}, open(self.output, 'w'))
		elif valid:
			print('Validating...')
			loss, accuracy = model.test(data_provider.validation, batch_size=200)
			print('mean cross_entropy: %f, mean accuracy: %f' % (loss, accuracy))
			json.dump({'valid_loss': '%s' % loss, 'valid_acc': '%s' % accuracy}, open(self.output, 'w'))
		elif pure:
			model.pure_train()
			loss, accuracy = model.test(data_provider.validation, batch_size=200)
			json.dump({'valid_loss': '%s' % loss, 'valid_acc': '%s' % accuracy}, open(self.output, 'w'))
			model.save_init(self.snapshot, print_info=(not pure))
			model.save_config(self.expdir, print_info=(not pure))
		else:
			# train the model
			print('Data provider train images: ', data_provider.train.num_examples)
			model.train_all_epochs(start_epoch)
			print('Data provider test images: ', data_provider.test.num_examples)
			print('Testing...')
			loss, accuracy = model.test(data_provider.test, batch_size=200)
			print('mean cross_entropy: %f, mean accuracy: %f' % (loss, accuracy))
			json.dump({'test_loss': '%s' % loss, 'test_acc': '%s' % accuracy}, open(self.output, 'w'))
			model.save_init(self.snapshot, print_info=(not pure))
			model.save_config(self.expdir, print_info=(not pure))
		return accuracy
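A minimal sketch of how this run method might be driven; the ExpdirMonitor name and its one-argument constructor are assumptions here, not something the snippet confirms:

# Hypothetical driver: restore a checkpointed experiment, then report validation accuracy.
monitor = ExpdirMonitor('exp/densenet_c10')  # assumed: a monitor class wrapping an experiment dir
val_acc = monitor.run(pure=False, restore=True, valid=True, valid_size=5000)
print('validation accuracy: %f' % val_acc)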
Example #2
    def get_start_net(self, copy=False):
        if self.start_net_config is None:
            # prepare start net
            print('Load start net from %s' % self.start_net_monitor.expdir)
            init = self.start_net_monitor.load_init()
            dataset = 'C10+' if init is None else init.get('dataset', 'C10+')
            run_config = self.start_net_monitor.load_run_config(
                print_info=True, dataset=dataset)
            run_config.renew_logs = False

            net_config, model_name = self.start_net_monitor.load_net_config(
                init, print_info=True)
            self.data_provider = get_data_provider_by_name(
                run_config.dataset, run_config.get_config())
            self.start_net_config = [net_config, run_config, model_name]
        if copy:
            net_config, run_config, model_name = self.start_net_config[:3]
            return [net_config.copy(), run_config.copy(), model_name]
        else:
            return self.start_net_config
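The copy flag exists so callers can mutate a configuration without corrupting the cached start net. A short sketch of both call styles (the manager name is hypothetical; renew_logs is an attribute the snippet itself uses):

# Safe mutation: copy=True returns fresh copies of the configs.
net_config, run_config, model_name = manager.get_start_net(copy=True)
run_config.renew_logs = True   # touches only the copy
# Shared access: the default returns the cached [net_config, run_config, model_name] list.
shared_configs = manager.get_start_net()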
Example #3
import json

import numpy as np
import tensorflow as tf

# Imports assumed from the repository layout used in the other examples;
# train_params_cifar is taken to be defined elsewhere in the same script.
from models.dense_net import DenseNet
from data_providers.utils import get_data_provider_by_name

with open('model_params.json', 'r') as fp:
    model_params = json.load(fp)

# some default params dataset/architecture related
train_params = train_params_cifar
print("Params:")
for k, v in model_params.items():
    print("\t%s: %s" % (k, v))
print("Train params:")
for k, v in train_params.items():
    print("\t%s: %s" % (k, v))


model_params['use_Y'] = False
print("Prepare training data...")
data_provider = get_data_provider_by_name(model_params['dataset'], train_params)
print("Initialize the model..")
tf.reset_default_graph()
model = DenseNet(data_provider=data_provider, **model_params)
print("Loading trained model")
model.load_model()

print("Data provider test images: ", data_provider.test.num_examples)
print("Testing...")
loss, accuracy = model.test(data_provider.test, batch_size=30)
print("mean cross_entropy: %f, mean accuracy: %f" % (loss, accuracy))




def labels_to_one_hot(labels, n_classes=43+1):
    # Body truncated in the source; standard one-hot construction assumed.
    new_labels = np.zeros((labels.shape[0], n_classes))
    new_labels[range(labels.shape[0]), labels] = np.ones(labels.shape)
    return new_labels
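A quick sanity check of the helper (assumes numpy is imported as np, as added at the top of this example):

labels = np.array([0, 2, 43])
one_hot = labels_to_one_hot(labels)
print(one_hot.shape)    # (3, 44): 43 classes plus one extra
print(one_hot[1, 2])    # 1.0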
Example #4
        args.bc_mode = True

    model_params = vars(args)

    if not args.train and not args.test:
        print("You should train or test your network. Please check params.")
        exit()

    # some default params dataset/architecture related
    train_params = get_train_params_by_name(args.dataset)
    print("Params:")
    for k, v in model_params.items():
        print("\t%s: %s" % (k, v))
    print("Train params:")
    for k, v in train_params.items():
        print("\t%s: %s" % (k, v))

    print("Prepare training data...")
    data_provider = get_data_provider_by_name(args.dataset, train_params)
    print("Initialize the model..")
    model = DenseNet(data_provider=data_provider, **model_params)
    if args.train:
        print("Data provider train images: ", data_provider.train.num_examples)
        model.train_all_epochs(train_params)
    if args.test:
        if not args.train:
            model.load_model()
        print("Data provider test images: ", data_provider.test.num_examples)
        print("Testing...")
        loss, accuracy = model.test(data_provider.test, batch_size=200)
        print("mean cross_entropy: %f, mean accuracy: %f" % (loss, accuracy))
Example #5
    if run_config.dataset in ['C10+', 'C100+']:
        net_config['keep_prob'] = 1.0
    if standard_net_config_cifar['model_type'] == 'DenseNet':
        net_config['reduction'] = 1.0
    if args.test:
        args.load_model = True

    # print configurations
    print('Run config:')
    for k, v in run_config.get_config().items():
        print('\t%s: %s' % (k, v))
    print('Network config:')
    for k, v in net_config.items():
        print('\t%s: %s' % (k, v))

    print('Prepare training data...')
    data_provider = get_data_provider_by_name(run_config.dataset,
                                              run_config.get_config())

    # set net config
    net_config = DenseNetConfig().set_standard_dense_net(
        data_provider=data_provider, **net_config)
    print('Initialize the model...')
    model = DenseNet(args.path, data_provider, run_config, net_config)

    # save configs
    if args.save_config:
        model.save_config(args.path)

    if args.load_model:
        model.load_model()
    if args.test:
        # test (truncated in the source; completed following the pattern of Example #4)
        print('Data provider test images: ', data_provider.test.num_examples)
        print('Testing...')
        loss, accuracy = model.test(data_provider.test, batch_size=200)
        print('mean cross_entropy: %f, mean accuracy: %f' % (loss, accuracy))
Example #6
densenet_params = {
    'dataset': 'C100+',
    'should_save_model': True,
    'reduction': 0.5,
    'renew_logs': False,
    'embedding_dim': 64,
    'display_iter': 100,
    'save_path':
    '/home/weilin/Downloads/densenet/saves/DenseNet-BC_growth_rate=12_depth=40_dataset_C100+/metric_learning.chkpt',
    'logs_path':
    '/home/weilin/Downloads/densenet/saves/DenseNet-BC_growth_rate=12_depth=40_dataset_C100+/metric_learning',
    'pretrained_model':
    '/home/weilin/Downloads/densenet/pretrained_model/DenseNet-BC_growth_rate=12_depth=40_dataset_C100+/model.chkpt',
    'margin_multiplier': 1.0,
}

data_provider = get_data_provider_by_name(densenet_params['dataset'],
                                          train_params_cifar)
model = DenseNet(data_provider, densenet_params)

training = True
feature_extract = False

if training:
    print("Data provider train images: ", data_provider.train.num_examples)
    model.load_pretrained_model()
    model.train_all_epochs(train_params_cifar)

if feature_extract:
    model.load_model()
    print("Data provider test images: ", data_provider.test.num_examples)
    feature_embeddings, gt_labels = model.feature_extracting(
        data_provider.test, batch_size=100)
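Once feature_extracting returns, the embeddings can be probed with plain numpy. A hedged nearest-neighbor sketch (assumes feature_embeddings is an (N, embedding_dim) array matching the config above):

import numpy as np

# Hypothetical follow-up: find the nearest neighbor of the first test embedding.
emb = np.asarray(feature_embeddings, dtype=np.float64)
emb /= np.linalg.norm(emb, axis=1, keepdims=True)   # L2-normalize each row
sims = emb @ emb[0]                                 # cosine similarity to sample 0
nearest = np.argsort(-sims)[1]                      # index 0 is the query itself
print('query label:', gt_labels[0], 'neighbor label:', gt_labels[nearest])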
Example #7
import os

import tensorflow as tf

from MLP_cifar import MLP
from data_providers.utils import get_data_provider_by_name

N_CLASSES = 10
BATCH_SIZE = 10
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0001

train_params = {
    'batch_size': BATCH_SIZE,
    'validation_set': False,
    'shuffle': 'every_epoch',
    'normalization': 'by_chanels',  # spelling kept as-is: it must match the provider's expected key
    'save_path': os.path.join('.', 'data', 'cifar')
}
data_provider = get_data_provider_by_name('C10', train_params)

# Save the current logging verbosity, then set the threshold for which messages
# are logged; levels are DEBUG, INFO, WARN, ERROR, or FATAL (here: ERROR).
old_v = tf.logging.get_verbosity()
tf.logging.set_verbosity(tf.logging.ERROR)
tf.logging.set_verbosity(old_v)  # restore the previous verbosity


hyper_param = {
    'learning_rate': 0.01,
    'keep_prob': 1.0,
    'weight_decay': WEIGHT_DECAY,
    'momentum': MOMENTUM,
    'state_switch': [1] * 50
}
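The snippet stops before the imported MLP is actually built. A minimal sketch of how hyper_param might feed it (the constructor and test signatures are assumptions patterned on the DenseNet examples above):

# Hypothetical: construct the MLP and evaluate it on the test split.
model = MLP(data_provider=data_provider, **hyper_param)
model.load_model()
loss, accuracy = model.test(data_provider.test, batch_size=BATCH_SIZE)
print("mean cross_entropy: %f, mean accuracy: %f" % (loss, accuracy))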
Example #8
    def build_data_provider(self):
        if self.use_torch_data_loader:
            if self.dataset == 'C10+':
                mean = [x / 255.0 for x in [125.3, 123.0, 113.9]]
                stdv = [x / 255.0 for x in [63.0, 62.1, 66.7]]

                train_transforms = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=mean, std=stdv),
                ])
                if self.cutout:
                    train_transforms.transforms.append(
                        Cutout(n_holes=self.cutout_n_holes,
                               length=self.cutout_size))

                test_transforms = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize(mean=mean, std=stdv),
                ])
                data_path = Cifar10DataProvider.default_data_path()
                train_set = datasets.CIFAR10(data_path,
                                             train=True,
                                             transform=train_transforms,
                                             download=True)
                test_set = datasets.CIFAR10(data_path,
                                            train=False,
                                            transform=test_transforms,
                                            download=False)
                if self.valid_size is not None:
                    valid_set = datasets.CIFAR10(data_path,
                                                 train=True,
                                                 transform=test_transforms,
                                                 download=True)
                    np.random.seed(
                        DataProvider.SEED
                    )  # set random seed before sampling validation set

                    indices = np.random.permutation(len(train_set))
                    train_indices = indices[self.valid_size:]
                    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
                        train_indices)
                    valid_indices = indices[:self.valid_size]
                    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
                        valid_indices)

                    train_loader = torch.utils.data.DataLoader(
                        train_set,
                        batch_size=self.batch_size,
                        sampler=train_sampler,
                        pin_memory=cuda_available(),
                        num_workers=1,
                        drop_last=self.drop_last)
                    valid_loader = torch.utils.data.DataLoader(
                        valid_set,
                        batch_size=self.batch_size,
                        sampler=valid_sampler,
                        pin_memory=cuda_available(),
                        num_workers=1,
                        drop_last=False)
                else:
                    train_loader = torch.utils.data.DataLoader(
                        train_set,
                        batch_size=self.batch_size,
                        shuffle=True,
                        pin_memory=cuda_available(),
                        num_workers=1,
                        drop_last=self.drop_last)
                    valid_loader = None

                test_loader = torch.utils.data.DataLoader(
                    test_set,
                    batch_size=100,
                    shuffle=False,
                    pin_memory=cuda_available(),
                    num_workers=1,
                    drop_last=False)
                if valid_loader is None:
                    valid_loader = test_loader
            else:
                raise NotImplementedError
        else:
            data_provider = get_data_provider_by_name(self.dataset,
                                                      self.get_config())
            train_loader = data_provider.train
            valid_loader = data_provider.validation
            test_loader = data_provider.test

            train_loader.batch_size = self.batch_size
            valid_loader.batch_size = self.batch_size
            test_loader.batch_size = 100
        return train_loader, valid_loader, test_loader
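A sketch of consuming the returned loaders in the torch path, where they are standard DataLoader instances (the run_config object name is hypothetical):

# Hypothetical caller: pull one training batch to sanity-check shapes.
train_loader, valid_loader, test_loader = run_config.build_data_provider()
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)   # e.g. torch.Size([64, 3, 32, 32]) torch.Size([64])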