def train():
    """Train the deconvolution saliency network from pretrained weights.

    Builds ``models.deconv_net(True)``, binds it on GPU 0 to the data and
    label shapes advertised by ``dataiter.SaliencyIter``, and initialises
    all parameters uniformly.  Parameters whose names also occur in the
    file at ``VGG_PATH`` are then replaced by the stored values.  Fitting
    uses Adam (learning rate 0.01) with an MSE metric for 100 epochs and
    registers a checkpoint callback on ``MODEL_PREFIX`` with period 1.
    """
    batches = dataiter.SaliencyIter()
    net = models.deconv_net(True)
    known_args = net.list_arguments()

    module = mx.mod.Module(
        symbol=net,
        context=mx.gpu(0),
        data_names=('data', ),
        label_names=('label', ),
    )
    module.bind(data_shapes=batches.provide_data,
                label_shapes=batches.provide_label)
    module.init_params(initializer=mx.init.Uniform(scale=.1))

    # Keep the freshly initialised values and overwrite only the entries
    # that the pretrained file provides for an argument of this symbol.
    arg_params, aux_params = module.get_params()
    pretrained = mx.nd.load(VGG_PATH)
    arg_params.update(
        {name: pretrained[name] for name in pretrained if name in known_args})
    module.set_params(arg_params, aux_params, allow_missing=True)

    # To resume from an earlier run instead, load it with
    # mx.model.load_checkpoint(MODEL_PREFIX, <epoch>) and hand the returned
    # parameters to set_params(..., allow_missing=True).

    fit_options = {
        'optimizer': 'adam',
        'optimizer_params': {'learning_rate': 0.01},
        'eval_metric': 'mse',
        'batch_end_callback': mx.callback.Speedometer(BATCH_SIZE, 10),
        'epoch_end_callback': mx.callback.do_checkpoint(MODEL_PREFIX, 1),
        'num_epoch': 100,
    }
    module.fit(batches, **fit_options)
def tune():
    """Continue training the saliency model from the ``tune`` checkpoint.

    Builds ``model(True)``, binds it on ``ctx`` to the shapes advertised by
    ``diter.SaliencyIter`` and initialises the parameters uniformly.  The
    epoch-50 checkpoint stored under ``'../params/tune'`` is then loaded
    and those of its argument parameters that the symbol actually lists
    are applied.  Fitting uses Adam (learning rate 0.001) with an MAE
    metric for 100 epochs and registers a checkpoint callback on the same
    prefix with period 1.
    """
    net = model(True)
    wanted = net.list_arguments()

    # The inferred shapes are not used afterwards; the call is kept so that
    # any error it raises for these input shapes still surfaces here.
    _unused_shapes, _, _ = net.infer_shape(
        data=(diter.batch_size, 3, 320, 240),
        label=(diter.batch_size, 1, 20, 15))

    batches = diter.SaliencyIter()
    module = mx.mod.Module(
        symbol=net,
        context=ctx,
        data_names=('data', ),
        label_names=('label', ),
    )
    module.bind(data_shapes=batches.provide_data,
                label_shapes=batches.provide_label)
    module.init_params(initializer=mx.init.Uniform(scale=.1))

    sym, stored_args, stored_aux = mx.model.load_checkpoint(
        '../params/tune', 50)
    # Drop checkpoint entries that are not arguments of the current symbol.
    usable_args = {
        name: stored_args[name] for name in stored_args if name in wanted}
    module.set_params(usable_args, stored_aux, allow_missing=True)

    fit_options = {
        'optimizer': 'adam',
        'optimizer_params': {'learning_rate': 0.001},
        'eval_metric': 'mae',
        'batch_end_callback': mx.callback.Speedometer(diter.batch_size, 5),
        'epoch_end_callback': mx.callback.do_checkpoint('../params/tune', 1),
        'num_epoch': 100,
    }
    module.fit(batches, **fit_options)
def predict():
    """Write a smoothed saliency map for every image found under ``root``.

    Loads the epoch-100 checkpoint stored under ``'../params/tune'`` into
    ``model(False)``.  Each image is scaled so that its shorter side is
    (approximately) 240 pixels, pushed through the network, and the
    resulting map is stretched to 0..255, resized back to the original
    image size and smoothed in two blur/rescale passes.  The result is
    saved as ``'../data/mit300/<name>.png'`` where ``<name>`` is the file
    name up to its first dot.

    Differences from the earlier revision of this function:

    * files in sub-directories are opened through the directory that
      ``os.walk`` reports instead of being looked up directly in ``root``;
    * images are converted to RGB so that greyscale or RGBA files yield
      the three channels the transpose below expects;
    * the forward pass is explicitly run with ``is_train=False``;
    * a constant network output or an all-black map no longer divides by
      zero.
    """
    root = 'E:/Dataset/MIT300/BenchmarkIMAGES/'
    symbol = model(False)
    dataiter = diter.SaliencyIter()
    mod = mx.mod.Module(symbol=symbol, context=ctx, data_names=('data', ))
    mod.bind(data_shapes=dataiter.provide_data)
    sym, arg_params, aux_params = mx.model.load_checkpoint(
        '../params/tune', 100)
    mod.set_params(arg_params, aux_params, allow_missing=True)

    for dirpath, _, fnames in os.walk(root):
        for fname in fnames:
            inp = Image.open(os.path.join(dirpath, fname)).convert('RGB')
            w, h = inp.size

            # Same scale for both axes, chosen from the shorter side.
            scale = 240 / float(min(w, h))
            w1, h1 = int(w * scale), int(h * scale)

            # PIL yields (height, width, channel); the network is fed
            # (channel, width, height) with a leading batch axis of 1.
            data = np.array(inp.resize((w1, h1))).transpose((2, 1, 0))
            mod.forward(Batch([mx.nd.array(np.expand_dims(data, axis=0))]),
                        is_train=False)

            # Element [0][0] of the first output — presumably the single
            # channel of the only sample in the batch (TODO confirm).
            out = mod.get_outputs()[0][0][0].asnumpy().transpose()
            lo, hi = np.amin(out), np.amax(out)
            if hi > lo:
                out = 255 * (out - lo) / (hi - lo)
            else:
                # A constant map carries no contrast; emit black instead
                # of dividing by zero.
                out = np.zeros_like(out)

            img = Image.fromarray(out.astype('uint8')).resize((w, h))
            img = _blur_and_rescale(img, 10)
            img = _blur_and_rescale(img, 25)
            img.save('../data/mit300/%s.png' % (fname.split('.')[0]))
            print('%s finished' % fname)


def _blur_and_rescale(img, radius):
    """Blur ``img`` by ``radius``, restore its peak to 255, then soften.

    The brightness factor is ``255 / max(pixel)`` measured after the first
    blur; an all-black image is left unscaled because no factor exists for
    it.  A final Gaussian blur of radius 2 is always applied.
    """
    img = img.filter(ImageFilter.GaussianBlur(radius))
    peak = float(np.amax(np.array(img)))
    if peak > 0:
        img = ImageEnhance.Brightness(img).enhance(255 / peak)
    return img.filter(ImageFilter.GaussianBlur(2))
def test():
    """Run one randomly chosen image through ``deconv_net`` and dump maps.

    Prints every argument name of the symbol with its inferred shape for a
    ``(1, 3, IN_W, IN_H)`` input, binds the network on the CPU, initialises
    it and overlays the parameters stored in
    ``'../params/vgg/vgg16-pre.params'`` that the symbol actually lists.
    One image picked at random from ``DATA_PATH`` is resized to
    ``(IN_W, IN_H)`` and forwarded; 512 slices of the output and the input
    image itself are written below ``'../pred/'`` at 640x480.

    Differences from the earlier revision of this function:

    * the input is reshaped to ``(1, 3, IN_W, IN_H)`` rather than a
      hard-coded ``(1, 3, 640, 480)``, matching the resize just before it;
    * files are opened through the directory ``os.walk`` reports, so
      images in sub-directories are found;
    * only ``Exception`` is caught, so Ctrl-C still interrupts the run;
    * the forward pass is explicitly run with ``is_train=False``;
    * an empty ``DATA_PATH`` results in no work instead of a logged
      ``IndexError``.

    Returns:
        Always ``0``; per-image failures are printed, not raised.
    """
    diter = dataiter.SaliencyIter()
    symbol = deconv_net()
    arg_names = symbol.list_arguments()
    arg_shapes, _, _ = symbol.infer_shape(data=(1, 3, IN_W, IN_H))
    for name, shape in zip(arg_names, arg_shapes):
        print(name + ' : ' + str(shape))

    model = mx.mod.Module(symbol=symbol, context=mx.cpu(),
                          data_names=('data',), label_names=('label',))
    # NOTE(review): this reads a private attribute of the module before it
    # is bound; it is passed through unchanged here — confirm what value
    # it holds at this point.
    model.bind(data_shapes=diter.provide_data,
               label_shapes=model._label_shapes)
    model.init_params()

    arg_params, aux_params = model.get_params()
    arg_params_load = mx.nd.load('../params/vgg/vgg16-pre.params')
    # Only take stored values for arguments this symbol really has.
    arg_params.update(
        {k: arg_params_load[k] for k in arg_params_load if k in arg_names})
    model.set_params(arg_params, aux_params, allow_missing=True)

    paths = []
    for dirpath, _, names in os.walk(DATA_PATH):
        paths.extend(os.path.join(dirpath, name) for name in names)
    random.shuffle(paths)

    for i, path in enumerate(paths[:1]):
        fname = os.path.basename(path)
        stem = fname.split('.')[0]
        try:
            img = np.array(Image.open(path).resize((IN_W, IN_H)))
            # (height, width, channel) -> (channel, width, height)
            img = np.swapaxes(img, 0, 2)
            model.forward(
                Batch([mx.nd.array(img).reshape((1, 3, IN_W, IN_H))]),
                is_train=False)
            pred = model.get_outputs()[0].asnumpy()
            # NOTE(review): 512 is hard-coded; presumably the number of
            # slices along the third axis of the output — confirm.
            for j in range(512):
                Image.fromarray(
                    pred[0][0][j].transpose().astype('uint8')
                ).resize((640, 480)).save(
                    '../pred/' + stem + str(j) + '_pred.jpg')
            Image.fromarray(
                np.swapaxes(img, 0, 2).astype('uint8')
            ).resize((640, 480)).save('../pred/' + fname)
            print('%sth img prediction finished' % i)
        except Exception:
            # Best effort: report the failure and carry on.
            traceback.print_exc()
    return 0
def feature():
    """Save the feature maps of one fixed image as individual PNG files.

    Builds ``model(False)``, binds it on ``ctx`` to the data shapes of
    ``diter.SaliencyIter`` and loads epoch 0 of the ``'../params/vgg16'``
    checkpoint, keeping only the argument and auxiliary entries that the
    symbol lists.  The image ``'../data/COCO_test2014_000000005572.jpg'``
    is forwarded, the first sample of the module's second output is
    stretched to 0..255 as a whole, and each slice along its first axis
    is written to ``'../data/vgg_feature/<index>.png'`` at 640x480.  No
    blur or brightness post-processing is applied.
    """
    net = model(False)
    arg_names = net.list_arguments()
    aux_names = net.list_auxiliary_states()

    batches = diter.SaliencyIter()
    module = mx.mod.Module(symbol=net, context=ctx, data_names=('data', ))
    module.bind(data_shapes=batches.provide_data)

    sym, stored_args, stored_aux = mx.model.load_checkpoint(
        '../params/vgg16', 0)
    # Discard checkpoint entries the current symbol does not know about.
    kept_args = {
        key: stored_args[key] for key in stored_args if key in arg_names}
    kept_aux = {
        key: stored_aux[key] for key in stored_aux if key in aux_names}
    module.set_params(kept_args, kept_aux, allow_missing=True)

    fname = 'COCO_test2014_000000005572'
    picture = Image.open('../data/%s.jpg' % fname)
    # (height, width, channel) -> (channel, height, width), plus batch axis.
    pixels = np.array(picture).transpose((2, 0, 1))
    module.forward(Batch([mx.nd.array(np.expand_dims(pixels, axis=0))]))

    out = module.get_outputs()[1][0].asnumpy()
    print(out.shape)

    # One global min/max for all maps, so relative strengths are preserved.
    lowest = np.amin(out)
    highest = np.amax(out)
    out = 255 * (out - lowest) / (highest - lowest)

    for index, fmap in enumerate(out):
        print(fmap.shape)
        tile = Image.fromarray(fmap.astype('uint8')).resize((640, 480))
        tile.save('../data/vgg_feature/%s.png' % index)
    print('finished')
def test():
    """Predict saliency maps for a random sample of images.

    Command-line arguments:
        ``sys.argv[1]``: epoch of the ``MODEL_PREFIX`` checkpoint to load.
        ``sys.argv[2]``: number of images to process.

    Binds ``models.deconv_net(False)`` on the CPU for inference and loads
    the checkpoint after dropping its ``'lossfactor'`` entry if present.
    Images found under ``img_path`` are shuffled; for each selected one
    the first channel of the first output is stretched, transposed,
    resized to the image size, converted to mode ``'L'``, smoothed and
    written to ``'../pred/<name>.jpg'``, where ``<name>`` is the file
    name up to its first dot.

    Differences from the earlier revision of this function:

    * both command-line arguments are parsed before any model is loaded;
    * ``'lossfactor'`` is removed with ``dict.pop`` instead of a
      catch-everything ``try``/``except``;
    * the label image is no longer opened — it was never used, and a
      missing label file used to abort the prediction for that image;
    * files are opened through the directory ``os.walk`` reports, so
      images in sub-directories are found;
    * asking for more images than exist processes the available ones
      instead of logging an ``IndexError`` per missing image;
    * only ``Exception`` is caught, so Ctrl-C still interrupts the run.

    Returns:
        Always ``0``; per-image failures are printed, not raised.
    """
    epoch = int(sys.argv[1])
    count = int(sys.argv[2])

    diter = dataiter.SaliencyIter()
    symbol = models.deconv_net(False)
    model = mx.mod.Module(symbol=symbol, context=mx.cpu(), label_names=None)
    # NOTE(review): binds to the iterator's private ``_provide_data``
    # rather than ``provide_data`` — confirm this is intentional.
    model.bind(for_training=False, data_shapes=diter._provide_data)

    sym, arg_params, aux_params = mx.model.load_checkpoint(MODEL_PREFIX, epoch)
    arg_params.pop('lossfactor', None)
    model.set_params(arg_params, aux_params, allow_missing=True)

    paths = []
    for dirpath, _, names in os.walk(img_path):
        paths.extend(os.path.join(dirpath, name) for name in names)
    random.shuffle(paths)

    for i, path in enumerate(paths[:count]):
        stem = os.path.basename(path).split('.')[0]
        try:
            img = Image.open(path)
            w, h = img.size
            # (height, width, channel) -> (channel, width, height)
            img = np.swapaxes(np.array(img), 0, 2)
            model.forward(Batch([mx.nd.array(img).reshape((1, 3, w, h))]),
                          is_train=False)

            pred = model.get_outputs()[0].asnumpy()[0][0]
            # NOTE(review): stretches to 0..254, not 0..255 — confirm that
            # the factor 254 is intended.
            pred = 254 * (pred - np.amin(pred)) / (np.amax(pred) - np.amin(pred))

            saliency = Image.fromarray(pred.transpose().astype('uint8'))
            saliency = saliency.resize((w, h)).convert('L')
            saliency = saliency.filter(ImageFilter.SMOOTH_MORE)
            saliency = saliency.filter(ImageFilter.GaussianBlur(5))
            saliency.save('../pred/' + stem + '.jpg')

            print('%sth img prediction finished with pixel range of %s~%s' % (
                i, np.amin(pred), np.amax(pred)))
        except Exception:
            # Best effort: report the failure and move on to the next image.
            traceback.print_exc()
    return 0