示例#1
0
def generate_token():
    """Generate a JWT for the user authenticating via HTTP Basic auth.

    Reads username/password from ``request.authorization``, verifies the
    password against the hash stored in ``user_collection`` and, on success,
    issues a token valid for 45 days and records its expiry on the user
    document.

    Returns:
        ``{"token": <str>, "expire": <datetime>}`` on success, otherwise a
        ``({"message": ...}, status)`` tuple: 401 for missing/invalid
        credentials, 404 when the username is unknown.
    """
    auth = request.authorization

    if not auth or not auth.username or not auth.password:
        return {"message": "Invalid login credentials"}, 401
    user = user_collection.find_one({"username": auth.username})

    # NOTE(review): answering 404 here (instead of the generic 401 used for a
    # wrong password) lets callers probe which usernames exist; kept as-is
    # because clients may depend on the status code.
    if not user:
        return {"message": "User does not exist"}, 404

    if not check_password_hash(user["password"], auth.password):
        return {"message": "Invalid login credentials"}, 401

    expire_datetime = datetime.datetime.utcnow() + datetime.timedelta(days=45)
    token = jwt.encode({"id": str(user["_id"]), "exp": expire_datetime}, SECRET_KEY)
    # PyJWT < 2.0 returns bytes while PyJWT >= 2.0 already returns str, so an
    # unconditional .decode() raises AttributeError on current PyJWT.
    if isinstance(token, bytes):
        token = token.decode("UTF-8")

    user_collection.update_one(
        {"username": auth.username}, {"$set": {"token_expire": expire_datetime}}
    )

    logger(
        msg=f"New token generated for user: {auth.username}",
        level="info",
        logfile="token",
    )
    return {"token": token, "expire": expire_datetime}
示例#2
0
def delete_user(user: Dict):
    """Endpoint to delete a (non-admin) user.

    Only admins may call it, and admin accounts cannot be deleted through it.
    The JSON body must be an object containing a ``username`` key.

    Args:
        user: info of user calling the API (its ``admin`` flag is checked)

    Returns:
        ``{"message": ...}`` on success, or a ``({"message": ...}, status)``
        tuple: 401 when not permitted, 400 for a malformed body, 404 when the
        target user does not exist.
    """
    if not user["admin"]:
        return {"message": "Not permitted to perform this operation."}, 401

    data = request.json
    # A body without "username" used to raise KeyError (HTTP 500); reject it
    # as a bad request instead.
    if not isinstance(data, dict) or "username" not in data:
        return {"message": "Incorrect parameters"}, 400
    username = data["username"]

    search_user = user_collection.find_one({"username": username})
    if not search_user:
        return {"message": "User not found"}, 404

    # Documents without an "admin" field are treated as regular users rather
    # than crashing with KeyError.
    if search_user.get("admin"):
        return {"message": "Not permitted to perform this operation."}, 401

    user_collection.delete_one({"username": username})

    logger(
        msg=f'User deleted by admin with username: {data["username"]}',
        level="info",
        logfile="token",
    )
    return {"message": "User deleted successfully."}
示例#3
0
 def run(self, port):
     """Log startup, start the network layer on ``port`` and enter the IOLoop.

     ``self.network.run`` is called synchronously (an earlier variant ran it
     on a separate thread); ``IOLoop.instance().start()`` then runs the
     Tornado loop until it is stopped.
     """
     logger().i("app runing.")
     logger().i("network run at %d", port)
     self.network.run(port)
     main_loop = ioloop.IOLoop.instance()
     main_loop.start()
示例#4
0
def create_user(user: Dict):
    """Endpoint to create a new user.

    Only admins may call it. The JSON body must be an object with
    ``username``, ``password`` and ``admin`` keys.

    Args:
        user: info of user calling the API (its ``admin`` flag is checked)

    Returns:
        ``{"message": ...}`` on success, or a ``({"message": ...}, status)``
        tuple: 401 when not permitted, 400 for a malformed body, 409 when the
        username is already taken.
    """
    if not user["admin"]:
        return {"message": "Not permitted to perform this operation."}, 401

    data = request.json
    # Missing keys used to raise KeyError (HTTP 500); reject them up front.
    required_keys = ("username", "password", "admin")
    if not isinstance(data, dict) or not all(key in data for key in required_keys):
        return {"message": "Incorrect parameters"}, 400

    search_user = user_collection.find_one({"username": data["username"]})
    if search_user:
        return {"message": "User already exists"}, 409

    user_collection.insert_one(
        {
            "admin": data["admin"],
            # Werkzeug 2.3 deprecated the plain "sha256" method and 3.0
            # rejects it; "pbkdf2:sha256" is a salted, iterated hash that
            # check_password_hash verifies on old and new Werkzeug alike.
            "password": generate_password_hash(data["password"], method="pbkdf2:sha256"),
            "username": data["username"],
        }
    )

    logger(
        msg=f'New user created by admin with username: {data["username"]}',
        level="info",
        logfile="token",
    )
    return {"message": "User created successfully."}
示例#5
0
    def request(cls, params):
        """Send one HTTP request described by ``params`` and summarise it.

        Args:
            params: dict with ``method`` (one of get/post/put/delete/option),
                ``entry`` and ``api_path`` (combined by
                ``cls.get_full_request_url``) and ``data`` (sent as the query
                string via ``params=``).

        Returns:
            dict with the original ``request_params``, a ``result`` summary
            (elapsed milliseconds, status code, body length) and the decoded
            JSON ``content``.

        Raises:
            AssertionError: unsupported method, or the response body is not a
                JSON object. Raised explicitly instead of via ``assert`` so the
                checks still run under ``python -O``.
        """
        method = params['method']
        if method not in ('get', 'post', 'put', 'delete', 'option'):  # allowed request methods
            raise AssertionError('unsupported request method: %r' % (method,))
        # ``requests`` has no ``option`` function; the OPTIONS verb is
        # ``requests.options``, so 'option' used to fail with AttributeError.
        send = getattr(requests, 'options' if method == 'option' else method)

        # Request the API
        url = cls.get_full_request_url(params['entry'], params['api_path'])
        logger().info('starting request api: ' + url)
        request_start = time()  # start time
        response = send(
            url,
            headers={
                'authorization': 'Bearer ' + cls.get_auth_token()
            },
            params=params['data'])
        request_stop = time()  # end time

        content = json.loads(response.text)  # decode the body once
        if not isinstance(content, dict):
            raise AssertionError('response body is not a JSON object')

        return {
            'request_params': params,  # request parameters
            'result': {
                'wait_time': round((request_stop - request_start) * 1000,
                                   2),  # elapsed time in ms
                'status_code': response.status_code,  # response status code
                'content_length': len(response.text),
            },
            'content': content,  # raw decoded response content
        }
示例#6
0
def preAnalysis(destination, events, images, run_volume, run_surface):
    """Prepare the inputs of one subject/task/session before the GLM analysis.

    Ensures ``<destination>/<subject>/INPUT_DATA/<task>/<session>`` exists
    (taken from ``images[0]``), copies the input data, formats and demeans
    the motion regressors, then demeans the volume images (``run_volume``)
    and/or the surface images of every hemisphere (``run_surface``).

    Any exception is printed and logged rather than raised, so callers get
    no failure signal from this function.
    """
    try:
        # NOTE(review): prints images[0] itself but only images[1].encoding;
        # confirm this asymmetry is intended.
        print(f"Running Preanalysis on: {images[0]} and {images[1].encoding}")

        # Create the subject's INPUT_DATA folder if it is missing; exist_ok
        # replaces the racy exists()-then-makedirs() check.
        first = images[0]
        input_dir = os.path.join(destination, first.subject, 'INPUT_DATA',
                                 first.task, first.session)
        os.makedirs(input_dir, exist_ok=True)

        # TODO: move the for-loop over images outside this function
        Copy_Input_Data.copy_input_data(images, destination, events)
        Format_Motion_Regressors.format_motion_regressors(
            destination, images)  # format the motion regressors
        Demean_Motion.demean_motion(destination, images)  # demean motion

        if run_volume:
            Demean_Images.volume_demean_images(
                destination, images)  # demean the volume images
        if run_surface:
            # ``hemispheres`` is a name defined outside this function.
            for hemisphere in hemispheres:
                Demean_Images.surface_demean_images(
                    destination, hemisphere,
                    images)  # demean the surface images
    except Exception as e:
        print(e)
        print(f"Error Running Preanalysis")
        logger.logger(e, 'error')
示例#7
0
 def on_receive(self, data):
     """Handle one inbound packet: answer DirInfo requests, close otherwise.

     ``data`` is decoded with ``pb_helper.BytesToMessage``. A
     ``message_common_pb2.DirInfo`` triggers ``self._send_dirinfo()``; any
     other message closes the connection.
     """
     logger().i("receive from %s", str(self.address))
     message = pb_helper.BytesToMessage(data)
     if not isinstance(message, message_common_pb2.DirInfo):
         self.close()
         return
     self._send_dirinfo()
示例#8
0
def checkinput(images):
    """Log every input image and report whether exactly two were supplied.

    Returns:
        True when ``len(images) == 2``, otherwise False (after logging a
        warning with the actual count).
    """
    count = len(images)
    ok = count == 2
    if not ok:
        logger.logger(f"WARNING: Found {count} images expected 2",
                      'warning')
    # NOTE(review): each image is logged at 'warning' level although the
    # message is informational; confirm the intended level.
    for img in images:
        logger.logger(f"Found Image {img.file}", 'warning')
    return ok
        def _find_keep_items(self):
            """Reorder ``sorted_list`` in place around pinned labels.

            If ``client_name`` is not None, the first entry whose label equals
            it is moved to the start of ``sorted_list``. If ``file_name`` is
            not None, the labels listed in that comma-separated file are moved,
            in their current order, to the end of ``sorted_list``.

            ``sorted_list``, ``matrix``, ``sort_row``, ``client_name`` and
            ``file_name`` come from the enclosing scope.
            """

            def _label(entry):
                # Matrix labels have been updated to include the group name at
                # this point, so compare against the part after "::: ".
                # Rows read matrix[i].Member, columns matrix[0][i].TopMember.
                index = int(entry[0])
                if sort_row:
                    return matrix[index].Member.Label.split("::: ")[1]
                return matrix[0][index].TopMember.Label.split("::: ")[1]

            # move client_name to start of list.
            if client_name is not None:
                _keep_start = [x for x in sorted_list if client_name == _label(x)]
                if _keep_start:
                    sorted_list.remove(_keep_start[0])
                    sorted_list.insert(0, _keep_start[0])

            # move _keep_at_end items to end of list.
            if file_name is not None:
                try:
                    # read the file_name file.
                    from utils.utilities import read_comma_separated_file
                    _keep_at_end = read_comma_separated_file(file_name)

                    if _keep_at_end is not None:
                        _keep_end = [
                            x for x in sorted_list for item in _keep_at_end
                            if item == _label(x)
                        ]
                        for item in _keep_end:
                            sorted_list.remove(item)
                            sorted_list.append(item)
                # Best-effort as before, but no longer a bare ``except:``,
                # which also swallowed KeyboardInterrupt/SystemExit.
                except Exception:
                    # str() so a non-str file_name cannot raise TypeError here.
                    logger("Unable to read _file_name: " + str(file_name))
示例#10
0
def checkroistats(roistats):
    """Verify that the ROI-stats output file exists and is non-empty.

    The file checked is ``roistats.working_dir/roistats.outfile``; a warning
    is logged when it is missing or zero bytes long.

    Returns:
        True if the file exists with a non-zero size, otherwise False.
    """
    filename = os.path.join(roistats.working_dir, roistats.outfile)
    if not os.path.exists(filename):
        logger.logger(f"WARNING: Could not find {filename}", 'warning')
        return False
    if os.stat(filename).st_size == 0:
        logger.logger(f"WARNING: {filename} is empty!", 'warning')
        return False
    return True
示例#11
0
def analysis(destination, images, run_volume, run_surface):
    """Build the GLM set for ``images`` and run the requested GLMs.

    Volume GLMs run when ``run_volume`` is truthy, surface GLMs when
    ``run_surface`` is truthy. Errors are printed and logged, not raised.

    Returns:
        The GLM set from ``get_GLMs``, or ``[]`` if building it failed.
    """
    glms = []
    try:
        # TODO: add the pre-parcellated data here
        glms = get_GLMs(destination, images)
        if run_volume:
            run_volume_glms(glms)
        if run_surface:
            run_surface_glms(glms)
    except Exception as err:
        print(err)
        print("Error Running Analysis")
        logger.logger(err, 'error')
    return glms
示例#12
0
def search():
    """Return documents related to the query in the JSON request body.

    Expects a JSON object with a ``text`` key holding the query.

    Returns:
        ``({"top_related_searches": ...}, 200)`` on success, or
        ``({"message": "Incorrect Parameters"}, 400)`` for a malformed body.
    """
    content = request.json

    # A body without "text" used to raise KeyError (HTTP 500).
    if not isinstance(content, dict) or "text" not in content:
        return {"message": "Incorrect Parameters"}, 400
    text = content["text"]

    logger(msg=f"Search request for query: {text}", level="info", logfile="info")
    top_related_searches = db_search(text)

    return {"top_related_searches": top_related_searches}, 200
示例#13
0
 def _send_dirinfo(self):
     """Reply with a DirInfo built from bin/version.xml and bin/version.txt.

     Both files are read as bytes from the ``bin`` directory next to this
     module: ``version.xml`` fills ``version`` and ``version.txt`` fills
     ``patches``. The serialised message is then sent to the peer.
     """
     base_dir = os.path.split(os.path.realpath(__file__))[0].replace("\\", "/")
     reply = message_common_pb2.DirInfo()
     sources = (
         ("version", "/bin/version.xml"),
         ("patches", "/bin/version.txt"),
     )
     for field_name, suffix in sources:
         # The with-block closes the file; no explicit close() needed.
         with open(base_dir + suffix, "rb") as handle:
             setattr(reply, field_name, handle.read())
     payload = pb_helper.MessageToSendBytes(reply)
     logger().i("send dirinfo to %s", str(self.address))
     self.send(payload)
示例#14
0
 def __init__(self):
     """Set up the logger, VPS count, proxy API URL and VPS status tables.

     ``vps_status_dict`` is produced by ``self.gen_vps_status_dict`` from the
     fixed list of status names kept in ``vps_static_status``.
     """
     self.logger = logger()
     self.total_vps_count = 100
     self.proxy_dev_api = 'http://dev.task.hxsv.data.caasdata.com/?action=proxy_ips'
     self.vps_os_info_dict = {}
     self.vps_static_status = [
         'vps_error', 'vps_success',
         'os_success', 'os_fail',
         'proxy_success',
         'fail_ssh', 'fail_ping',
     ]
     self.vps_status_dict = self.gen_vps_status_dict(self.vps_static_status)
示例#15
0
 def __init__(self, host, port, user, passwd):
     """Open a MySQL connection plus a cursor that returns rows as dicts.

     ``host``, ``port``, ``user`` and ``passwd`` are passed straight to
     ``pymysql.connect``; the charset is fixed to ``utf8``.
     """
     connect_kwargs = dict(
         host=host,
         port=port,
         user=user,
         passwd=passwd,
         charset='utf8',
     )
     # Establish the database connection.
     self.conn = pymysql.connect(**connect_kwargs)
     # DictCursor makes query results come back as dictionaries.
     self.cur = self.conn.cursor(cursor=pymysql.cursors.DictCursor)
     self.logger = logger()
示例#16
0
 def __init__(self) -> None:
     """Build the server's subsystems, then start it.

     Each manager/interface below receives ``self``. The construction order is
     kept as-is because later components may rely on attributes set by earlier
     ones (not verifiable from this block). ``self.start()`` runs last, so
     constructing this object also starts it.
     """
     self.setup_config()
     self.command_manager: object = command_manager(self)
     self.command_interface: object = command_interface(self)
     self.event_manager: object = event_manager(self)
     self.rak_net_interface: object = rak_net_interface(self)
     # NOTE(review): self.logger is assigned only after the components above
     # were built; any of them that logs via server.logger in its __init__
     # would not find it yet — confirm.
     self.logger: object = logger()
     self.plugin_manager: object = plugin_manager(self)
     # Presumably keyed by player/connection id — TODO confirm against users.
     self.players: dict = {}
     # Starts at 1; presumably the next entity id to hand out — TODO confirm.
     self.current_entity_id: int = 1
     self.start()
示例#17
0
    def __createSpark(self, tweepy):
        """Launch sparkManager.py as a background process for this sentiment job.

        Args:
            tweepy: object exposing ``address`` and ``port``, passed to the
                Spark job as ``--host`` / ``--port``.

        Returns:
            PID of the spawned process (it is not waited for).
        """
        import sys
        project_dir = '/Users/melihozkan/Desktop/Projects/BitirmeProjesi/'
        # Append only once; appending on every call grew sys.path unboundedly.
        if project_dir not in sys.path:
            sys.path.append(project_dir)
        from utils.logger import logger
        log = logger()
        log.createLog(str(self.sentimentid))
        spark: sparkServerModel = sparkServerModel.objects.get(serverid=1)
        pc = "python /Users/melihozkan/Desktop/Projects/BitirmeProjesi/sparkManager.py --host {} --port {} --sentimentId {}  --master {} --method {}"
        cmd = pc.format(tweepy.address, tweepy.port, self.sentimentid,
                        spark.generate_address(), self.method)
        self.eprint(cmd)

        # subprocess.DEVNULL discards the child's output without the
        # os.devnull file handle that was previously opened and never closed.
        sub = subprocess.Popen(shlex.split(cmd),
                               stdout=subprocess.DEVNULL,
                               stderr=subprocess.DEVNULL)
        log.log("info", "Subprocess Created")
        return sub.pid
示例#18
0
def check_evts(input, image):
    """Check the EVT files of ``image`` and fail if every line read is blank.

    Files matching ``<input>/<subject>_<task>_<session>*txt`` are read line
    by line. A line that is empty after stripping whitespace and ``*`` is
    printed and logged as a blank EVT.

    Raises:
        NameError: if at least one line was read and all of them were blank.
    """
    pattern = os.path.join(input, f"{image.subject}_{image.task}_{image.session}*txt")
    files = glob.glob(pattern)
    logger.logger("Moving evts", "info")
    good_files = []
    for file in files:
        logger.logger(f"Checking evt: {file}", "info")
        # glob already returns paths that include ``input``; joining them
        # with ``input`` again doubled a relative prefix (input/input/...).
        with open(file, "r") as a_file:
            for line in a_file:
                if len(line.strip().strip('*')) == 0:
                    print('found empty line in file ' + str(file))
                    logger.logger(f'Found blank evt {str(file)}', 'warning')
                    good_files.append(False)
                else:
                    good_files.append(True)
    # Decide only after every file was read. This check used to sit inside
    # the per-line loop, so a blank first line aborted immediately even when
    # later lines or files were fine.
    if good_files and not any(good_files):
        print("All of the evts had issues")
        logger.logger(f'ERROR: All evts were blank', 'error')
        raise NameError(f'ERROR: All evts were blank for {image.subject} {image.session} {image.task}')
示例#19
0
    def __init__(self, args):
        """Set up loaders, model, optimizer, loss, LR schedule, logger and checkpoints.

        Args:
            args: parsed options. Reads train_path, valid_path, batch_size,
                dataset, cuda, model ('padding_vectornet' or 'vectornet'),
                depth_sub, width_sub, depth_global, width_global, lr,
                loss_mode, lr_scheduler and epochs.

        Raises:
            AssertionError: if ``args.model`` is not a supported model name.

        When ``saver.restore()`` returns a checkpoint, the model and optimizer
        state, start epoch and best prediction are restored from it.
        """
        self.args = args

        self.train_loader, self.valid_loader = construct_loader(
            args.train_path, args.valid_path, args.batch_size, args.dataset, args.cuda)

        # Define model and its optimizer parameter group.
        if args.model == 'padding_vectornet':
            model_cls = padding_VectorNet
        elif args.model == 'vectornet':
            model_cls = VectorNet
        else:
            # Explicit raise instead of ``assert False``: under ``python -O``
            # the assert vanished and ``model`` was left unbound.
            raise AssertionError('Error!!\nUnsupported model: {}'.format(args.model))
        self.model = model_cls(args.depth_sub, args.width_sub,
                               args.depth_global, args.width_global)
        train_params = [{'params': self.model.parameters(), 'lr': args.lr}]

        # CUDA enabled
        if args.cuda:
            self.model = self.model.cuda()

        self.optimizer = torch.optim.Adam(train_params)
        self.criterion = loss_collection(args.cuda).construct_loss(args.loss_mode)

        # learning-rate schedule
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        self.metricer = metricer()

        ckpt_dir = 'ckpt/{}'.format(args.model)
        storage_dir = 'ckpt/{}/storage'.format(args.model)
        os.makedirs(ckpt_dir, 0o777, exist_ok=True)
        self.logger = logger(ckpt_dir, ['DE@1s', 'DE@2s', 'DE@3s', 'ADE', 'loss'])
        os.makedirs(storage_dir, 0o777, exist_ok=True)
        self.saver = saver(storage_dir, args.model)
        ret = self.saver.restore()
        self.start_epoch = 1
        self.best_pred = 0
        if ret is not None:
            # ret = (model state, optimizer state, start epoch, best prediction)
            self.model.load_state_dict(ret[0])
            self.optimizer.load_state_dict(ret[1])
            self.start_epoch = ret[2]
            self.best_pred = ret[3]
示例#20
0
 def __init__(self, hostname: str, port: int, sentimentId: str, master: str):
     """Create the Spark streaming context that reads text from hostname:port.

     Configures SparkConf (app name ``sentiment_<sentimentId>``, the given
     master, 4G executor/driver memory, 600s network timeout), reuses or
     creates the SparkContext, and opens a socket text stream with 10-second
     batches.
     """
     self.logger = logger()
     self.logger.setSentiemtnId(sentimentId)
     self.__hostname = hostname
     self.__port = port
     self.__appName = "sentiment_" + sentimentId
     self.logger.log("info", "Initializing Spark Instance")

     conf = SparkConf()
     conf.setAppName(self.__appName)
     conf.setMaster(master)
     spark_settings = (
         ("spark.executor.memory", "4G"),
         ("spark.driver.memory", "4G"),
         ("spark.network.timeout", "600s"),
     )
     for key, value in spark_settings:
         conf.set(key, value)
     # Cassandra connector settings (disabled):
     # conf.set("spark.cassandra.connection.host", "134.122.166.110")
     # conf.set("spark.cassandra.connection.port", "9042")

     self.__sc: SparkContext = SparkContext.getOrCreate(conf)
     self.__sc.setLogLevel("ERROR")
     self.__ssc = StreamingContext(self.__sc, batchDuration=10)
     self.__spark = SQLContext(self.__sc)
     self.__dataStream = self.__ssc.socketTextStream(
         hostname=self.__hostname,
         port=self.__port,
     )
     self.logger.log("info", "Spark Inıtıalized")
示例#21
0
	def clone_report(modeladmin, request, queryset):
		"""Admin action: clone each selected DailyReport together with its items.

		For every report ``q`` in ``queryset``:
		  * its DailyReportItems are fetched, then ``q`` itself is turned into
		    the clone (title + "_clone", pk reset, saved as a new row);
		  * each item is copied onto the clone with ``report_date`` set to now;
		  * photos whose ``follow_up_date_end`` is in the future are linked to
		    ``get_related_daily_report_item()``, and their follow-up end is set
		    to now + 1 day.
		Finally a success message is shown to the admin user.
		"""
		from datetime import timedelta
		# ``ungettext`` was deprecated in Django 3.0 and removed in 4.0;
		# ``ngettext`` is the equivalent name.
		from django.utils.translation import ngettext

		for q in queryset:
			logger( "clone_report {}".format(q) )
			# Materialise the original report's items before q is turned into
			# the clone below, so they are unambiguously the source rows.
			qset_daily_report_item = list(
				DailyReportItem.objects.filter( daily_report=q ) )
			# Clone the report in place: with pk=None, save() inserts a new row
			# and q / daily_report_latest then refer to that copy.
			daily_report_latest = q
			daily_report_latest.report_date = q.report_date
			daily_report_latest.title = daily_report_latest.title + "_clone"
			daily_report_latest.pk = None
			daily_report_latest.save()
			for q_daily_report_item in qset_daily_report_item:
				q_daily_report_item.report_date = timezone.localtime()
				q_daily_report_item.daily_report = daily_report_latest
				q_daily_report_item.pk = None
				q_daily_report_item.save()
			# NOTE(review): the photo steps below do not use q and are repeated
			# once per selected report; kept inside the loop because
			# get_related_daily_report_item() may see the newly created items.
			today = timezone.localtime()
			qset_photo = Photo.objects.filter( follow_up_date_end__gt = today )
			logger( "clone_report qset_photo len={}".format( len(qset_photo) ) )
			for q_photo in qset_photo:
				q_daily_report_item = q_photo.get_related_daily_report_item()
				if q_daily_report_item is not None:
					q_photo.daily_report_item.add( q_daily_report_item )
					q_photo.save()
			# Extend follow-up: set follow_up_date_end to now + 1 day.
			qset_photo_followup = Photo.objects.filter(
				follow_up_date_end__gt = timezone.localtime() )
			logger( '{} photos to followup, +1 date'.format(
				len( qset_photo_followup) ) )
			qset_photo_followup.update(
				follow_up_date_end = timezone.localtime() + timedelta(days=1) )
		msg = ngettext(
			"Successfully cloned report",
			"Successfully cloned reports",
			len( queryset )
		)
		messages.success( request, msg )
示例#22
0
import sendrequest as req
import utils.logs as logs
import os
import urlparse

from itertools import islice
from utils.logger import logger
from utils.db import Database_update
from utils.config import get_value

# Module-level helpers shared by the functions of this module.
dbupdate = Database_update()  # database helper from utils.db
api_logger = logger()  # logger instance from utils.logger

# Redirect target without a scheme; presumably what the redirect payloads are
# checked against — TODO confirm against the callers.
redirection_url = "www.google.com"


def fetch_open_redirect_payload():
    """Return the open-redirect payloads listed in ``Payloads/redirect.txt``.

    The path is relative to the current working directory (one level up when
    running from a directory named ``API``). The first line of the file is
    skipped, trailing whitespace is stripped and blank lines are ignored.

    Returns:
        list of payload strings, in file order.
    """
    # basename() also works with Windows separators, unlike split('/').
    if os.path.basename(os.getcwd()) == 'API':
        path = os.path.join('..', 'Payloads', 'redirect.txt')
    else:
        path = os.path.join('Payloads', 'redirect.txt')

    payload_list = []
    with open(path) as f:
        # islice(f, 1, None) skips the first line of the file.
        for line in islice(f, 1, None):
            payload = line.rstrip()
            # Raw lines always end in "\n", so the old ``if line:`` never
            # filtered blank lines; test the stripped text instead.
            if payload:
                payload_list.append(payload)

    return payload_list
示例#23
0
import unittest

from page.page_login import page_login
from utils.logger import logger
from ddt import ddt, data, unpack

from utils.common import get_parent_path
import os
from utils.readcsv import get_csv_data
# Module-level logger named "test_login"; this rebinding shadows the imported
# ``logger`` factory for the rest of the module.
logger = logger(logger="test_login").getlog()


@ddt
class test_login(unittest.TestCase):
    test_account = (("admin", "123456", "admin"), ("admin", "123456", "admin"),
                    ("admin", "123456", "admin"))
    current_path = os.path.dirname(__file__)
    logger.info("__file__ current_path:%s" % current_path)
    test_account = get_csv_data(
        os.path.join(get_parent_path(current_path), "data", "user.csv"))
    logger.info("test_account:%s" % test_account)
    test_account_fail = get_csv_data(
        os.path.join(get_parent_path(current_path), "data", "user_fail.csv"))

    @classmethod
    def setUpClass(cls):
        cls.p_login = page_login()

    @data(*test_account)
    @unpack
    def test_login_pass(self, name, password, expect_reult):
# Fix the random seeds (PyTorch CPU, all CUDA devices, NumPy) for reproducibility.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
np.random.seed(0)

# Experiment paths: config at ./config/<exp>.py, outputs under ../data/jingwei/<exp>.
exp_config = os.path.join(".", "config", args.exp + ".py")
exp_dir = os.path.join("../data/jingwei", args.exp)
exp_log_dir = os.path.join(exp_dir, "log")
if not os.path.exists(exp_log_dir):
    os.makedirs(exp_log_dir)
# Load the experiment parameters (the config module must define ``config``).
# NOTE(review): ``imp`` is deprecated and removed in Python 3.12; importlib is
# the replacement.
config = imp.load_source("", exp_config).config
# tensorboard && logger: one log file per run, named by the start timestamp.
# NOTE(review): str(datetime) contains ':' characters, which are invalid in
# Windows file names.
now_str = datetime.datetime.now().__str__().replace(' ', '_')

logger_path = os.path.join(exp_log_dir, now_str + ".log")
logger = logger(logger_path).get_logger()

# NOTE(review): CUDA honours this only if it is set before CUDA is initialised;
# confirm nothing earlier in the script has initialised it.
os.environ["CUDA_VISIBLE_DEVICES"] = '1, 2, 3'

train_config = config['train_config']

logger.info('preparing data......')

train_dataset = jingwei_train_dataset(csv_root=train_config['csv_root'], )
# Shuffled training batches; drop_last discards a final incomplete batch.
trainloader = DataLoader(dataset=train_dataset,
                         batch_size=train_config['batch_size'],
                         shuffle=True,
                         num_workers=args.num_workers,
                         drop_last=True)
logger.info('data done!')
示例#25
0
def outputVerifier(images, GLM_set):
    """Check the input images and the generated GLM outputs, logging a verdict.

    Uses ``checkinput(images)`` and ``checkoutput(GLM_set)``. A check that
    raises is logged as an error and counted as failed.
    """
    # Default to "not good" so a check that raises no longer leaves the flag
    # unbound (previously an UnboundLocalError at the final test).
    input_good = False
    output_good = False
    try:
        logger.logger(f"Verifiying input", 'info')
        input_good = checkinput(images)
    except Exception:
        logger.logger(
            f"Error while checking input images make sure they exist", 'error')
    try:
        logger.logger(f"Verifiying output", 'info')
        output_good = checkoutput(GLM_set)
    except Exception:
        logger.logger(
            f"Error while checking output files make sure they exist", 'error')
    if input_good and output_good:
        logger.logger(f"Both input and output look good", 'info')
    else:
        # Message previously read "There was an with either ..." (missing word).
        logger.logger(f"There was an issue with either the input or the output",
                      'error')
示例#26
0
# -*- coding: utf-8 -*-

import os
import sys
import socket

from tornado import web, ioloop, httpserver, process, netutil

from router import ROUTERS
from conf import SETTINGS, DATABASE

from utils.logger import logger

# Module-level logger created with the name 'admin' (utils.logger.logger).
log = logger('admin')


class Application(web.Application):
    """Tornado application wired with the project's ROUTERS and SETTINGS."""

    def __init__(self):
        # Route table from router.ROUTERS; every entry of conf.SETTINGS is
        # forwarded as a Tornado application setting.
        super(Application, self).__init__(handlers=ROUTERS, **SETTINGS)


if __name__ == '__main__':
    # The first CLI argument selects the action; only 'run' is handled here.
    # NOTE(review): with no CLI arguments, ``args[0]`` raises IndexError.
    args = sys.argv[1:]
    if args[0] == 'run':
        app = Application()
        print('Starting server on port 9000...')
        # Disabled multi-process variant: pre-bind sockets and fork workers.
        # sockets = netutil.bind_sockets(9000, '127.0.0.1', socket.AF_UNSPEC)
        # process.fork_processes(5)
        server = httpserver.HTTPServer(app)
        server.listen(9000)
        # server.start(num_processes=4)
        # NOTE(review): no IOLoop is started in the lines shown; unless it is
        # started after this point (e.g. ioloop.IOLoop.current().start()),
        # the process exits right after listen() — TODO confirm.
示例#27
0
# Collect the first whitespace-separated token of every line of commands.py
# into ``commands_list`` and count the lines in ``t``.
# NOTE(review): a blank line in commands.py makes split() return [] and the
# [0] index raise IndexError.
with open('commands.py', 'r') as f:
    for line in f:
        commands_list.append(line.split(None, 1)[0])
        t += 1

# Read ``file1`` line by line until ``commands_total`` lines have been counted
# (or EOF), logging each non-comment line as a loaded command.
a = 0
while commands_total != a:
    line = file1.readline()
    if not line:
        break
    # str.find returns 0 only when "#" is the first character, so this branch
    # is taken for comment lines (same as line.startswith("#")).
    if not line.find("#"):
        a += 1
    else:
        # NOTE(review): ``file`` is fetched here but not used in this branch.
        file = logger.log_file()
        # NOTE(review): ``a`` also counts comment lines, so commands_list[a]
        # assumes file1 lines align with commands.py lines — confirm.
        logger.logger(
            logger.debugger_args("main", "commands_total",
                                 "command=" + commands_list[a]))
        print("Command " + commands_list[a] + " loaded.")
        a += 1
    if commands_total == a:
        # All expected lines seen: write a summary record to the log file
        # descriptor returned by logger.log_file().
        time.sleep(0.7)
        file = logger.log_file()
        logger.logger(logger.debugger_args("main", "total", ""))
        os.write(
            file,
            str.encode("[%s] (STARTER): All commands loaded (%s)" %
                       (str(datetime.now().strftime("%H:%M:%S")), str(a))))
        print("All commands loaded (%s)" % (str(a)))
    time.sleep(0.3)

for event in lp.listen():
示例#28
0
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import sys
sys.path.append('/zywa/aoam')
import re

from utils.logger import logger
from utils import config_util, exeCmd
from utils.environment_util import environment_util

env = environment_util()  # helper instance from utils.environment_util
conf = config_util.getDict('flume')  # the 'flume' configuration, as returned by config_util.getDict
log = logger(loggername='flume')  # module logger named 'flume'
'''
检测索引是否存在,如果集群挂了,是无法返回索引的存在与否


def checkServer():
    hostAndPorts = conf.get('hosts')
    es = Elasticsearch(hostAndPorts)
    #print(es.indices.exists(index="test"))  #填写对应索引
    lists = es.cluster.health()
    logger.info("elasticsearch集群状态:",end="")
    ES_cluster_status = lists["status"]
    if ES_cluster_status == "green":
        logger.info("####集群处于健康状态####")
    elif ES_cluster_status == "yellow":
        logger.info("集群处于亚健康状态")
    elif ES_cluster_status == "red":
        logger.warn("集群挂了")
    logger.info("elasticsearch集群节点数:"+lists['number_of_nodes'])
示例#29
0
def pruning_naive(task_name, path_model, percent, plan_pruning, batch_size, entropy_factor, cnt_train):
    """Prune LSTM hidden units with masks from ``get_mask_lstm`` and fine-tune.

    For every LSTM layer, a hidden-unit mask is computed from the layer's
    kernel. That mask drops the matching kernel columns (replicated over four
    column blocks) and the matching rows of the next layer / ``fc5``. The
    pruned weights are pickled to ``<path_save>/test.pickle``, a new
    ``LstmNet`` is built, evaluated once and trained per ``plan_pruning``.

    Args:
        task_name: dataset/task name used to build or load the config.
        path_model: path of the pickled weights dict; its directory holds
            cfg.ini when ``cnt_train`` is truthy.
        percent: pruning amount passed to ``get_mask_lstm``.
        plan_pruning: iterable of dicts with ``n_epochs``, ``lr`` and
            ``save``; one ``model.train`` call each.
        batch_size: stored in cfg['basic'] for a fresh run (cnt_train == 0).
        entropy_factor: stored in cfg['basic'].
        cnt_train: training-round counter; 0 creates a new config and run dir.

    Returns:
        ``(name, model.cnt_train + 1)``, where ``name`` is the value returned
        by the last ``model.train`` call, or None if ``plan_pruning`` is empty.
    """
    # load model
    path_cfg = '/'.join(path_model.split('/')[:-1]) + '/cfg.ini'
    if cnt_train:
        cfg = load_cfg(path_cfg, localize=False, task_name=task_name)
    else:
        cfg = get_cfg(dataset_name=task_name, file_cfg_model='net_lstm.cfg')

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(path_model, 'rb') as f_model:
        weights_dict = pickle.load(f_model)

    n_lstm = cfg['structure'].getint('n_layers_lstm')

    # Input features are never pruned: start with an all-True input mask.
    mask_last = [True for _ in range(cfg['data'].getint('dimension'))]
    # lstm
    for layer_index in range(n_lstm):
        name_prefix = 'rnn/multi_rnn_cell/cell_%d/lstm_cell' % layer_index

        mask = get_mask_lstm(weights_dict[name_prefix + '/weights'], percent)

        # Kernel columns: four blocks sized by the hidden units (presumably
        # the four LSTM gates).
        mask_kernel = np.concatenate([mask, mask, mask, mask]).astype(bool)
        # Kernel rows: [layer input ; recurrent hidden state].
        mask_input = np.concatenate([mask_last, mask]).astype(bool)

        weights_dict[name_prefix + '/weights'] = weights_dict[name_prefix + '/weights'][mask_input, :][:, mask_kernel]
        weights_dict[name_prefix + '/biases'] = weights_dict[name_prefix + '/biases'][mask_kernel]

        mask_last = mask
    # output layer: keep only the rows fed by surviving last-layer units
    layer_name = 'fc5'
    weights_dict[layer_name + '/weights'] = weights_dict[layer_name + '/weights'][mask_last, :]

    if cnt_train == 0:
        time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

        cfg['basic']['task_name'] = task_name
        cfg['basic']['time_stamp'] = time_stamp
        cfg['basic']['batch_size'] = str(batch_size)
        cfg['basic']['pruning_method'] = 'naive_pruning'
        cfg.add_section('pruning')
        cfg.set('pruning', 'percent', str(percent))

        cfg['path']['path_save'] += '%s-%s' % (task_name, time_stamp)
        cfg['path']['path_cfg'] = cfg['path']['path_save'] + '/cfg.ini'
        cfg['path']['path_log'] = cfg['path']['path_save'] + '/log.log'
        cfg['path']['path_dataset'] = cfg['path']['path_dataset'] + cfg['basic']['task_name'] + '/'

    cfg.set('basic', 'entropy_factor', str(entropy_factor))
    logger(cfg['path']['path_log'])

    # save dir
    os.makedirs(cfg['path']['path_save'], exist_ok=True)

    # save model; the with-block guarantees the pickle is flushed and closed
    # before it is (presumably) read back via path_load by LstmNet.
    path_pruned = cfg['path']['path_save'] + '/test.pickle'
    with open(path_pruned, 'wb') as f_pruned:
        pickle.dump(weights_dict, f_pruned)
    cfg['path']['path_load'] = str(path_pruned)

    gpu_config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)

    tf.reset_default_graph()
    # NOTE(review): this session is never closed; repeated calls may hold
    # device memory until garbage collection.
    sess = tf.Session(config=gpu_config)

    model = LstmNet(cfg)
    model.cnt_train = cnt_train

    sess.run(tf.global_variables_initializer())

    log_l('Pre test')
    model.eval_once(sess, epoch=0)
    log_l('')

    model.save_cfg()
    # Stays None for an empty plan (previously an UnboundLocalError on return).
    name = None
    for plan in plan_pruning:
        name = model.train(sess=sess, n_epochs=plan['n_epochs'], lr=plan['lr'], save=plan['save'])
        model.save_cfg()

    return name, model.cnt_train + 1
示例#30
0
#!/usr/bin/python3
# -*- encoding:utf-8 -*-
import sys

sys.path.append('/zywa/aoam')
import subprocess as sbp
from utils.logger import logger

# Module-level logger instance (utils.logger.logger).
log = logger()


def check_call(cmd):
    """Run ``cmd`` and report the outcome as an exit code.

    Args:
        cmd: argument list, or a command-line string that is split with
            ``shlex.split``. No shell is involved, so pipes and redirection
            are not interpreted.

    Returns:
        0 if the command ran and exited with status 0, otherwise 1 (after
        printing a failure message).
    """
    import shlex

    try:
        # Without shell=True, subprocess treats a *string* as the program
        # name itself, so e.g. "ls -l" could never run; split it into argv.
        argv = shlex.split(cmd) if isinstance(cmd, str) else cmd
        sbp.check_call(argv)
    except Exception:
        # Best-effort boundary (non-zero exit, missing program, bad quoting):
        # report failure as 1. Unlike the former bare ``except:``, Ctrl-C /
        # SystemExit are no longer swallowed.
        # str() keeps list commands from raising TypeError here.
        print("运行: " + str(cmd) + "命令失败")
        return 1
    return 0


'''
获取执行结果
'''


def getoutput(cmd):
    """Run ``cmd`` through the shell and return its combined stdout/stderr.

    Thin wrapper over ``subprocess.getoutput``, which strips one trailing
    newline from the output.
    """
    output = sbp.getoutput(cmd)
    return output


'''