コード例 #1
0
def train():
    """Train the deconvolution saliency net, seeding it with VGG weights.

    Builds the training symbol via ``models.deconv_net(True)`` and binds it
    on GPU 0 using the data/label shapes reported by
    ``dataiter.SaliencyIter``. Every parameter is first drawn from
    ``Uniform(0.1)``; any parameter whose name appears both in the network's
    argument list and in the file at ``VGG_PATH`` is then overwritten with
    the stored value. Training runs for 100 epochs with Adam (lr 0.01) on an
    MSE metric, writing a checkpoint under ``MODEL_PREFIX`` after every
    epoch.
    """
    train_iter = dataiter.SaliencyIter()
    net = models.deconv_net(True)
    wanted = net.list_arguments()

    module = mx.mod.Module(symbol=net,
                           context=mx.gpu(0),
                           data_names=('data', ),
                           label_names=('label', ))
    module.bind(data_shapes=train_iter.provide_data,
                label_shapes=train_iter.provide_label)

    # Random init first, so layers that are absent from the VGG file still
    # hold values after set_params(allow_missing=True).
    module.init_params(initializer=mx.init.Uniform(scale=.1))
    args, auxs = module.get_params()
    # Assumes the file maps parameter names to arrays (saved with names).
    pretrained = mx.nd.load(VGG_PATH)
    # Copy only weights the network declares; other entries are ignored.
    args.update({name: pretrained[name]
                 for name in pretrained if name in wanted})
    module.set_params(args, auxs, allow_missing=True)

    # To resume from a checkpoint instead:
    # sym, arg_params, aux_params = mx.model.load_checkpoint(MODEL_PREFIX, 2)
    # model.set_params(arg_params, aux_params, allow_missing=True)

    module.fit(
        train_iter,
        optimizer='adam',
        optimizer_params={'learning_rate': 0.01},
        eval_metric='mse',
        batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 10),
        epoch_end_callback=mx.callback.do_checkpoint(MODEL_PREFIX, 1),
        num_epoch=100,
    )
コード例 #2
0
ファイル: foo.py プロジェクト: hallazie/LWSP
def tune(load_prefix='../params/tune', load_epoch=50,
         save_prefix='../params/tune'):
    """Continue training the saliency model from a saved checkpoint.

    Binds ``model(True)`` on ``ctx`` using the shapes reported by
    ``diter.SaliencyIter`` and randomly initialises every parameter with
    ``Uniform(0.1)``. It then overwrites each argument and auxiliary state
    that the checkpoint ``load_prefix``/``load_epoch`` holds *and* the
    symbol declares. Finally it trains for 100 epochs with Adam (lr 1e-3,
    MAE metric), checkpointing every epoch under ``save_prefix``.

    Args:
        load_prefix: checkpoint prefix to resume from.
        load_epoch: epoch number of the checkpoint to load.
        save_prefix: prefix for the per-epoch checkpoints written by ``fit``.
            NOTE: with the defaults this equals ``load_prefix``. ``fit``
            starts at MXNet's default ``begin_epoch=0``, so the run
            overwrites the checkpoint it started from once it reaches
            epoch ``load_epoch``.
    """
    symbol = model(True)
    arg_names = set(symbol.list_arguments())
    aux_names = set(symbol.list_auxiliary_states())
    dataiter = diter.SaliencyIter()
    mod = mx.mod.Module(symbol=symbol,
                        context=ctx,
                        data_names=('data', ),
                        label_names=('label', ))
    # bind() infers all shapes from the iterator itself; a separate
    # infer_shape() call whose result is discarded adds nothing.
    mod.bind(data_shapes=dataiter.provide_data,
             label_shapes=dataiter.provide_label)
    # Random init first, so parameters missing from the checkpoint still
    # hold values after set_params(allow_missing=True).
    mod.init_params(initializer=mx.init.Uniform(scale=.1))
    _, arg_params, aux_params = mx.model.load_checkpoint(load_prefix,
                                                         load_epoch)
    # set_params(allow_missing=True) hands the dicts straight to the
    # executors, which reject names the symbol does not declare. So keep
    # only matching arguments *and* auxiliary states.
    arg_params = {k: v for k, v in arg_params.items() if k in arg_names}
    aux_params = {k: v for k, v in aux_params.items() if k in aux_names}
    mod.set_params(arg_params, aux_params, allow_missing=True)
    mod.fit(
        dataiter,
        optimizer='adam',
        optimizer_params={'learning_rate': 0.001},
        eval_metric='mae',
        batch_end_callback=mx.callback.Speedometer(diter.batch_size, 5),
        epoch_end_callback=mx.callback.do_checkpoint(save_prefix, 1),
        num_epoch=100,
    )
コード例 #3
0
ファイル: foo.py プロジェクト: hallazie/LWSP
def predict(root='E:/Dataset/MIT300/BenchmarkIMAGES/',
            out_dir='../data/mit300'):
    """Write a smoothed saliency map for every image found under ``root``.

    Loads epoch 100 of the ``../params/tune`` checkpoint into the inference
    symbol ``model(False)``, keeping only the parameters that symbol
    declares. Then, for each file found by walking ``root``:

    1. Rescale the image so its shorter side is 240 px.
    2. Run a forward pass.
    3. Min-max stretch the first output map to 0..255 and resize it back to
       the source size.
    4. Apply two blur + brightness-stretch passes.
    5. Save the result as ``<out_dir>/<stem>.png``.

    Args:
        root: directory walked recursively for input images. Every file
            found is opened as an image.
        out_dir: directory the PNG saliency maps are written to.
    """
    symbol = model(False)
    arg_names = set(symbol.list_arguments())
    aux_names = set(symbol.list_auxiliary_states())
    dataiter = diter.SaliencyIter()
    mod = mx.mod.Module(symbol=symbol, context=ctx, data_names=('data', ))
    mod.bind(data_shapes=dataiter.provide_data)
    _, arg_params, aux_params = mx.model.load_checkpoint(
        '../params/tune', 100)
    # Entries the inference symbol does not declare would be rejected by
    # set_params(allow_missing=True), so drop them.
    arg_params = {k: v for k, v in arg_params.items() if k in arg_names}
    aux_params = {k: v for k, v in aux_params.items() if k in aux_names}
    mod.set_params(arg_params, aux_params, allow_missing=True)
    for dirpath, _, fs in os.walk(root):
        for f in fs:
            # Join with the directory actually being walked (root + f is
            # wrong for files in sub-directories). convert('RGB') guarantees
            # three channels; grayscale/RGBA would break the transpose below.
            inp = Image.open(os.path.join(dirpath, f)).convert('RGB')
            w, h = inp.size
            # Scale so the shorter side is exactly 240 px. Integer maths
            # avoids float round-off turning 240 into 239.
            short = min(w, h)
            w1, h1 = w * 240 // short, h * 240 // short
            # np.array(PIL image) is (H, W, C); this feeds the net (C, W, H).
            data = np.array(inp.resize((w1, h1))).transpose((2, 1, 0))
            mod.forward(Batch([mx.nd.array(np.expand_dims(data, axis=0))]))
            # First output, single batch item, first channel; transpose back
            # to (H, W) for PIL.
            out = mod.get_outputs()[0][0][0].asnumpy().transpose()
            lo, hi = np.amin(out), np.amax(out)
            if hi > lo:
                out = 255 * (out - lo) / (hi - lo)
            else:
                # Flat map: avoid 0/0 (NaN) and emit an all-black image.
                out = np.zeros_like(out)
            img = Image.fromarray(out.astype('uint8')).resize((w, h))
            for radius in (10, 25):
                img = img.filter(ImageFilter.GaussianBlur(radius))
                peak = np.amax(np.array(img))
                # Stretch so the strongest response reaches 255; leave an
                # all-black map unchanged instead of dividing by zero.
                factor = 255 / float(peak) if peak else 1.0
                img = ImageEnhance.Brightness(img).enhance(factor).filter(
                    ImageFilter.GaussianBlur(2))
            img.save(os.path.join(out_dir, '%s.png' % os.path.splitext(f)[0]))
            print('%s finished' % f)
コード例 #4
0
def test():
    """Smoke-test a freshly initialised ``deconv_net`` on one random image.

    1. Print the inferred shape of every network argument.
    2. Bind the net on the CPU and initialise it with ``init_params()``.
    3. Overlay the parameters stored in ``../params/vgg/vgg16-pre.params``
       that the network declares.
    4. Run one image picked at random from ``DATA_PATH``, resized to
       ``(IN_W, IN_H)``, through the net.
    5. Save 512 slices of the output, plus a copy of the resized input,
       under ``../pred/``.

    Returns:
        0 if processing the image raised (the traceback is printed);
        otherwise None.
    """
    diter = dataiter.SaliencyIter()
    symbol = deconv_net()
    arg_names = symbol.list_arguments()
    arg_shapes, _, _ = symbol.infer_shape(data=(1, 3, IN_W, IN_H))
    for name, shape in zip(arg_names, arg_shapes):
        print(name + ' : ' + str(shape))

    model = mx.mod.Module(symbol=symbol, context=mx.cpu(),
                          data_names=('data',), label_names=('label',))
    # No label shapes are known here; pass None explicitly rather than
    # reading the module's private, not-yet-bound _label_shapes.
    model.bind(data_shapes=diter.provide_data, label_shapes=None)

    model.init_params()
    arg_params, aux_params = model.get_params()
    arg_params_load = mx.nd.load('../params/vgg/vgg16-pre.params')
    # set_params(allow_missing=True) rejects names the network does not
    # declare, so only copy matching VGG weights.
    wanted = set(arg_names)
    for k in arg_params_load:
        if k in wanted:
            arg_params[k] = arg_params_load[k]
    model.set_params(arg_params, aux_params, allow_missing=True)

    fname_list = []
    for _, _, files in os.walk(DATA_PATH):
        fname_list.extend(files)
    random.shuffle(fname_list)
    for i in range(1):
        try:
            fname = fname_list[i]
            # convert('RGB') guarantees exactly three channels.
            img = np.array(Image.open(DATA_PATH + '/' + fname)
                           .convert('RGB').resize((IN_W, IN_H)))
            # (H, W, C) -> (C, W, H), i.e. (3, IN_W, IN_H).
            img = np.swapaxes(img, 0, 2)
            # Add the batch axis from the real shape instead of a hard-coded
            # (1, 3, 640, 480), which scrambles or fails when IN_W/IN_H differ.
            model.forward(Batch([mx.nd.array(np.expand_dims(img, axis=0))]))
            pred = model.get_outputs()[0].asnumpy()
            stem = fname.split('.')[0]
            # NOTE(review): assumes pred[0][0] has at least 512 entries
            # along its first axis — confirm against the net's output shape.
            for j in range(512):
                Image.fromarray(pred[0][0][j].transpose().astype('uint8')).resize(
                    (640, 480)).save('../pred/' + stem + str(j) + '_pred.jpg')
            Image.fromarray(np.swapaxes(img, 0, 2).astype('uint8')).resize(
                (640, 480)).save('../pred/' + fname)
            print('%sth img prediction finished' % i)
        except Exception:
            traceback.print_exc()
            return 0
コード例 #5
0
ファイル: foo.py プロジェクト: hallazie/LWSP
def feature(fname='COCO_test2014_000000005572'):
    """Dump every channel of the network's second output as PNG images.

    Loads epoch 0 of the ``../params/vgg16`` checkpoint into the inference
    symbol ``model(False)``, keeping only the arguments and auxiliary states
    that symbol declares. It then runs ``../data/<fname>.jpg`` through the
    network and takes output index 1 for the single batch item. Those
    channels are min-max scaled *jointly* to 0..255, and each one is resized
    to 640x480 and written to ``../data/vgg_feature/<channel>.png``.

    Args:
        fname: stem of the JPEG under ``../data/`` to visualise.
    """
    symbol = model(False)
    arg_names = set(symbol.list_arguments())
    aux_names = set(symbol.list_auxiliary_states())
    dataiter = diter.SaliencyIter()
    mod = mx.mod.Module(symbol=symbol, context=ctx, data_names=('data', ))
    mod.bind(data_shapes=dataiter.provide_data)
    _, arg_params, aux_params = mx.model.load_checkpoint(
        '../params/vgg16', 0)
    # Drop checkpoint entries the symbol does not declare.
    arg_params = {k: v for k, v in arg_params.items() if k in arg_names}
    aux_params = {k: v for k, v in aux_params.items() if k in aux_names}
    mod.set_params(arg_params, aux_params, allow_missing=True)
    # convert('RGB') guarantees three channels; grayscale or RGBA inputs
    # would otherwise break the (2, 0, 1) transpose.
    # NOTE(review): this feeds (C, H, W) while the predict/test paths feed
    # (C, W, H) — confirm which layout the network expects.
    data = np.array(Image.open('../data/%s.jpg' % fname).convert('RGB'))
    data = data.transpose((2, 0, 1))
    mod.forward(Batch([mx.nd.array(np.expand_dims(data, axis=0))]))
    out = mod.get_outputs()[1][0].asnumpy()
    print(out.shape)
    lo, hi = np.amin(out), np.amax(out)
    if hi > lo:
        out = 255 * (out - lo) / (hi - lo)
    else:
        # Constant output: avoid 0/0 (NaN) and write all-black images.
        out = np.zeros_like(out)
    for i, e in enumerate(out):
        print(e.shape)
        img = Image.fromarray(e.astype('uint8')).resize((640, 480))
        img.save('../data/vgg_feature/%s.png' % i)
    print('finished')
コード例 #6
0
def test():
    """Predict saliency maps for randomly chosen images under ``img_path``.

    Command-line arguments:
        sys.argv[1]: epoch of the ``MODEL_PREFIX`` checkpoint to load.
        sys.argv[2]: number of images to process; capped at the number of
            files found.

    The inference symbol ``models.deconv_net(False)`` is bound on the CPU and
    loaded with the checkpoint entries it declares. For each image:

    1. Min-max scale the prediction to 0..254 (an all-zero map if it is
       constant).
    2. Resize it to the input size and convert it to grayscale.
    3. Apply SMOOTH_MORE and then GaussianBlur(5).
    4. Save it as ``../pred/<stem>.jpg``.

    Returns:
        0 if processing an image raised (traceback printed); otherwise None.
    """
    epoch = int(sys.argv[1])
    count = int(sys.argv[2])
    diter = dataiter.SaliencyIter()
    symbol = models.deconv_net(False)
    arg_names = set(symbol.list_arguments())
    aux_names = set(symbol.list_auxiliary_states())

    model = mx.mod.Module(symbol=symbol, context=mx.cpu(), label_names=None)
    model.bind(for_training=False, data_shapes=diter._provide_data)

    _, arg_params, aux_params = mx.model.load_checkpoint(MODEL_PREFIX, epoch)
    # set_params(allow_missing=True) rejects names the symbol does not
    # declare. Filtering by name removes every such entry (e.g. a
    # training-only 'lossfactor'), not just one hand-picked key.
    arg_params = {k: v for k, v in arg_params.items() if k in arg_names}
    aux_params = {k: v for k, v in aux_params.items() if k in aux_names}
    model.set_params(arg_params, aux_params, allow_missing=True)

    fname_list = []
    for _, _, files in os.walk(img_path):
        fname_list.extend(files)
    random.shuffle(fname_list)
    for i in range(min(count, len(fname_list))):
        fname = fname_list[i]
        try:
            # convert('RGB') guarantees three channels for the reshape below.
            img = Image.open(img_path + '/' + fname).convert('RGB')
            w, h = img.size
            # (H, W, C) -> (C, W, H), then add the batch axis: (1, 3, w, h).
            data = np.swapaxes(np.array(img), 0, 2)
            model.forward(Batch([mx.nd.array(data).reshape((1, 3, w, h))]),
                          is_train=False)
            pred = model.get_outputs()[0].asnumpy()[0][0]
            lo, hi = np.amin(pred), np.amax(pred)
            if hi > lo:
                pred = 254 * (pred - lo) / (hi - lo)
            else:
                # Flat prediction: avoid 0/0 (NaN) and emit an all-zero map.
                pred = np.zeros_like(pred)
            Image.fromarray(pred.transpose().astype('uint8')).resize(
                (w, h)).convert('L').filter(ImageFilter.SMOOTH_MORE).filter(
                    ImageFilter.GaussianBlur(5)).save(
                        '../pred/' + fname.split('.')[0] + '.jpg')
            print('%sth img prediction finished with pixel range of %s~%s' % (
                i, np.amin(pred), np.amax(pred)))
        except Exception:
            traceback.print_exc()
            return 0