def main():
    """Classify the object found in the masked region of a raw image.

    Command-line arguments:
        raw_path: path of the raw image; the mask image path is derived
            from it with ``raw_to_mask_path``.
        model_path: HDF5 file whose parameters are loaded into a
            ``VGG_mini_ABN`` model.

    The raw image is cropped to the ROI returned by ``mask_to_roi``,
    resized to 128x128, run through the model with ``train`` disabled,
    and the ``OBJECT_CLASSES`` entry at the arg-max of the output is printed.
    """
    cli_parser = argparse.ArgumentParser()
    for positional in ('raw_path', 'model_path'):
        cli_parser.add_argument(positional)
    cli = cli_parser.parse_args()

    # Load the raw image together with its mask and crop to the mask's ROI.
    mask_file = raw_to_mask_path(cli.raw_path)
    raw_image = imread(cli.raw_path)
    mask_image = imread(mask_file)
    top, left, bottom, right = mask_to_roi(mask_image)
    cropped = raw_image[top:bottom, left:right]

    # Restore the trained network.
    net = VGG_mini_ABN()
    serializers.load_hdf5(cli.model_path, net)

    # Build a single-sample float32 batch from the resized crop.
    resized = resize(cropped, (128, 128), preserve_range=True)
    batch = np.array([im_to_blob(resized)], dtype=np.float32)
    net_input = Variable(batch, volatile=True)

    # Forward pass in inference mode.
    net.train = False
    scores = net(net_input).data
    print(OBJECT_CLASSES[np.argmax(scores[0])])
def _attach_optimizer(model, optimizer, on_gpu):
    """Move ``model`` to the GPU if requested, then set ``optimizer`` up on it."""
    if on_gpu:
        model.to_gpu()
    optimizer.setup(model)


def main():
    """Train the model selected on the command line.

    Command-line arguments:
        supervised_or_not: ``supervised`` or ``unsupervised``.
        --epoch: number of epochs passed to ``Trainer.main_loop``
            (default: 50).
        --no-logging: write the log into a fresh temporary directory instead
            of ``<here>/../logs/<timestamp>_<model>``.
        --save-interval: forwarded to ``Trainer.main_loop``.
        -m/--model: model name. Supervised: ``VGG_mini_ABN``,
            ``CAEOnesRoiVGG``. Unsupervised: ``CAE``, ``CAEOnes``,
            ``CAEPool``, ``StackedCAE``.

    Exits with status 1 (after writing to stderr) when the model name is not
    supported for the chosen mode.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'supervised_or_not', type=str, choices=['supervised', 'unsupervised'],
        help='do supervised or unsupervised training')
    parser.add_argument('--epoch', type=int, default=50,
                        help='number of recursion (default: 50)')
    parser.add_argument('--no-logging', action='store_true',
                        help='logging to tmp dir')
    parser.add_argument('--save-interval', type=int, default=None,
                        help='save interval of x and x_hat')
    parser.add_argument('-m', '--model', required=True,
                        help='name of model')
    args = parser.parse_args()

    n_epoch = args.epoch
    save_interval = args.save_interval
    is_supervised = args.supervised_or_not == 'supervised'
    on_gpu = True
    batch_size = 10
    save_encoded = False
    crop_roi = False
    optimizers = [O.Adam()]

    # Messages produced before logging is configured.  Calling logging.info()
    # here would implicitly install a default stderr handler on the root
    # logger, which turns the later logging.basicConfig(filename=...) into a
    # no-op, so they are queued and emitted once the log file is set up.
    pending_log_messages = []

    if is_supervised:
        if args.model == 'VGG_mini_ABN':
            from apc_od.models import VGG_mini_ABN
            model = VGG_mini_ABN()
            _attach_optimizer(model, optimizers[0], on_gpu)
            crop_roi = True
        elif args.model == 'CAEOnesRoiVGG':
            from apc_od.pipeline import CAEOnesRoiVGG
            initial_roi = np.array([0, 0, 356, 534])
            pending_log_messages.append(
                'initial_roi: {}'.format(initial_roi))
            initial_roi = roi_preprocess(initial_roi)
            # setup model
            model = CAEOnesRoiVGG(initial_roi=initial_roi,
                                  learning_rate=0.2,
                                  learning_n_sample=300)
            if on_gpu:
                model.to_gpu()
            # one optimizer per sub-model of the pipeline
            optimizers = [O.Adam(), O.Adam()]
            optimizers[0].setup(model.cae_ones1)
            optimizers[1].setup(model.vgg2)
            # load trained models
            serializers.load_hdf5(os.path.join(here, 'cae_ones_model.h5'),
                                  model.cae_ones1)
            serializers.load_hdf5(os.path.join(here, 'vgg_model.h5'),
                                  model.vgg2)
            # load optimizers state
            serializers.load_hdf5(os.path.join(here, 'cae_ones_optimizer.h5'),
                                  optimizers[0])
            serializers.load_hdf5(os.path.join(here, 'vgg_optimizer.h5'),
                                  optimizers[1])
        else:
            sys.stderr.write('Unsupported model: {}\n'.format(args.model))
            sys.exit(1)
    else:
        # unsupervised
        if args.model == 'CAE':
            from apc_od.models import CAE
            save_encoded = True
            model = CAE()
            _attach_optimizer(model, optimizers[0], on_gpu)
        elif args.model == 'CAEOnes':
            from apc_od.models import CAEOnes
            model = CAEOnes()
            _attach_optimizer(model, optimizers[0], on_gpu)
        elif args.model == 'CAEPool':
            from apc_od.models import CAEPool
            save_encoded = True
            model = CAEPool()
            _attach_optimizer(model, optimizers[0], on_gpu)
        elif args.model == 'StackedCAE':
            from apc_od.models import StackedCAE
            save_encoded = True
            model = StackedCAE()
            _attach_optimizer(model, optimizers[0], on_gpu)
        else:
            sys.stderr.write('Unsupported model: {}\n'.format(args.model))
            sys.exit(1)

    timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

    # setup for logging
    if args.no_logging:
        import tempfile
        log_dir = tempfile.mkdtemp()
    else:
        log_dir = osp.join(here, '../logs/{}_{}'.format(timestamp, args.model))
        log_dir = osp.realpath(osp.abspath(log_dir))
        # makedirs (not mkdir) so a missing parent ``logs`` directory is
        # created as well; still raises if ``log_dir`` itself already exists.
        os.makedirs(log_dir)
    log_file = osp.join(log_dir, 'log.txt')
    logging.basicConfig(
        format='%(asctime)s [%(levelname)s] %(message)s',
        filename=log_file,
        level=logging.DEBUG,
    )
    # Flush the messages queued before the log file existed.
    for pending in pending_log_messages:
        logging.info(pending)
    logging.info('args: {};'.format(args))
    msg = 'logging in {};'.format(log_dir)
    logging.info(msg)
    print(msg)

    trainer = Trainer(
        optimizers=optimizers,
        model=model,
        model_name=args.model,
        is_supervised=is_supervised,
        crop_roi=crop_roi,
        batch_size=batch_size,
        log_dir=log_dir,
        log_file=log_file,
        on_gpu=on_gpu,
    )
    trainer.main_loop(
        n_epoch=n_epoch,
        save_interval=save_interval,
        save_encoded=save_encoded,
    )