示例#1
0
    def _init_train_ops(self):
        """Read training hyper-parameters from ``self.args`` and build helpers.

        Sets the rollout/batch geometry, creates the output directory, and
        builds the reward tracker, the logger (with its column headers) and
        the Adam optimizer over ``self.policy``'s parameters.
        """
        args = self.args
        self.n_step = args.n_step
        self.log_iters = args.log_iters
        self.learning_rate = args.lr
        self.gamma = args.gamma
        self.tau = args.tau
        self.vf_coef = args.vf_coef
        self.ent_coef = args.ent_coef
        self.threads = self.env.num_envs
        # Presumably n_step rollout steps from each of the parallel envs per update.
        self.n_batch = self.n_step * self.threads
        # Floor division + 1 guarantees train_iters * n_batch exceeds the timestep budget.
        self.train_iters = (args.timesteps // self.n_batch) + 1
        self.max_grad_norm = 0.5
        self.outdir = args.outdir
        # makedirs also creates missing parent directories, and exist_ok avoids the
        # check-then-create race of os.path.exists() + os.mkdir().
        os.makedirs(self.outdir, exist_ok=True)

        # Fall back to one life when the args carry no 'game_lives' setting.
        self.rew_tracker = RewardTracker(self.threads,
                                         getattr(args, 'game_lives', 1))
        self.logger = Logger(self.outdir)
        self.optimizer = optim.Adam(self.policy.parameters(),
                                    lr=self.learning_rate)

        # TODO: may not be the best way to do this
        headers = [
            "timestep", 'mean_rew', "best_mean_rew", "episodes", "policy_loss",
            "value_loss", "entropy", "time_elapsed"
        ]
        self.logger.set_headers(headers)
示例#2
0
 def run(self):
     """Process entry point: dispatch queued messages forever.

     Each queue item is unpacked as ``(target, type, msg)``. ``'NOTICE'``
     items are passed to ``self._notice`` and ``'LOG'`` items to
     ``self._logger``; items with any other target are ignored. The loop has
     no exit condition, so this never returns normally.
     """
     self._log = Logger(PRJ_PATH)
     while True:
         target, msg_type, msg = self._queue.get()
         if target == 'NOTICE':
             self._notice(msg_type, msg)
             continue
         if target == 'LOG':
             self._logger(msg_type, msg)
示例#3
0
    def __init__(self, brokers, currency):
        """Set up arbitrage bookkeeping for ``currency`` across ``brokers``.

        All tracking dicts start empty, then ``update_profit_spread()`` is
        called so spreads are computed on construction. Balances are not
        refreshed here.
        """
        self.brokers = brokers
        self.currency = currency
        self.log = Logger()
        # prices:        hi_bids and lo_asks for each broker
        # balances:      base and alt balances for each exchange
        # profit_spread: price spreads with transaction fees applied. Keyed first by
        #                the bidder exchange, then by the asker exchange, e.g.
        #                ['CRYPTSY']['COINSE'] = -0.1 (the original note describes
        #                this as raw bid - ask; presumably fee-adjusted here - confirm)
        # profits:       actual ALT profits, accounting for balances and volumes
        self.prices, self.balances, self.profit_spread, self.profits = {}, {}, {}, {}

        self.update_profit_spread()
示例#4
0
文件: Bot.py 项目: enplus/parttime
 def __init__(self, config, brokers):
     """Create the trading bot.

     :param config: configuration object; TICK_DIR and the DEBUG* flags are read from it
     :param brokers: list of broker objects the bot works with
     """
     super(Bot, self).__init__()
     self.config, self.brokers = config, brokers
     self.error = False
     self.log = Logger()
     # Run / tick-data state.
     self.backtest_data = None
     self.max_ticks = 0
     self.data_path = abspath(config.TICK_DIR)
     self.trading_enabled = True
     self.tick_i = 0
     # Debug switches copied verbatim from the configuration.
     for attr_name, config_key in (('debug', 'DEBUG'),
                                   ('debug_cls', 'DEBUG_CLEARSCREEN'),
                                   ('debug_mktdata', 'DEBUG_MARKETDATA'),
                                   ('debug_spread', 'DEBUG_SPREAD')):
         setattr(self, attr_name, getattr(config, config_key))
示例#5
0
def main():
    """Build the demo testbed topology and render its WAN and site intents.

    Returns immediately if the module-level ``net`` is already truthy, so
    repeated calls do not rebuild anything. An optional topology config file
    name may be given as the first command-line argument.
    """
    global net
    global renderers
    global rendererIndex
    global vpns
    global vpnIndex
    if net:
        return

    renderers = []
    rendererIndex = {}
    vpns = []  # used in vpn.py only
    vpnIndex = {}  # used in vpn.py only

    random.seed(0)  # in order to get the same result to simplify debugging
    InitLogger()
    if len(sys.argv) > 1:
        net = TestbedTopology(fileName=sys.argv[1])
    else:
        net = TestbedTopology()

    # One-time setup for the VPN service
    wi = WanIntent("esnet", net.builder.wan)
    wr = WanRenderer(wi)
    wr.execute()
    renderers.append(wr)
    rendererIndex[wr.name] = wr

    for site in net.builder.sites:
        intent = SiteIntent(name=site.name, site=site)
        sr = SiteRenderer(intent)
        suc = sr.execute()  # no function without vpn information
        if not suc:
            Logger().warning('%r.execute() fail', sr)
        renderers.append(sr)
        rendererIndex[sr.name] = sr
    # Function-call form prints the same text on Python 2 and 3; the bare
    # print statement is a SyntaxError on Python 3.
    print("Now the demo environment is ready.")
示例#6
0
class Result:
    """Outcome of evaluating one task rule within a mission.

    Truthiness mirrors ``pass_``, so callers can simply write ``if result:``.
    """
    logger = Logger('result')

    def __init__(self, msg: str, pass_: bool, rule: str, index: int, task_id: int,
                 task_name: str, mission_name: str):
        # Which task (and where in its mission) produced this result.
        self.index, self.task_id = index, task_id
        self.task_name, self.mission_name = task_name, mission_name
        # The outcome itself.
        self.pass_, self.msg, self.rule = pass_, msg, rule

    def __bool__(self):
        return self.pass_

    def __str__(self):
        separator = '<' + '=' * 50 + '>'
        return '\n'.join((self.rule, self.msg, separator))

    @classmethod
    def create_by_task_engine(cls, task_engine, pass_):
        """Build a Result from a task engine, prefixing the task message with PASS/FAIL."""
        cls.logger.info('try to create by task engine')
        task = task_engine.task
        prefix = 'PASS ' if pass_ else 'FAIL '
        return cls(prefix + task.msg, pass_, task.RULE, task_engine.task_index,
                   task_engine.task_id, task_engine.task_name,
                   task_engine.mission_name)

    @property
    def short_str(self):
        """The message alone, without rule or separator."""
        return self.msg
示例#7
0
from celery.schedules import schedule, crontab
from collections import OrderedDict
from datetime import datetime
import time

from common.utils import Logger
from configs import TIMEZONE

logger = Logger('TIME HELPER')


def _is_on_minute_boundary(step):
    """Return True when the current minute in ``TIMEZONE`` is a multiple of *step*.

    Uses ``%`` rather than ``//``: ``minute // step == 0`` is only true during
    the first *step* minutes of each hour, not on every *step*-th minute.
    """
    return datetime.now(tz=TIMEZONE).minute % step == 0


def can_run_1min():
    """Always return True: a 1-minute job may run on every tick."""
    logger.info('1 min always return True')
    return True


def can_run_5min():
    """Return True on minutes :00, :05, ..., :55."""
    logger.info('to konw 5 min if can run or not')
    return _is_on_minute_boundary(5)


def can_run_15min():
    """Return True on minutes :00, :15, :30 and :45."""
    logger.info('to konw 15 min if can run or not')
    return _is_on_minute_boundary(15)


def can_run_30min():
    """Return True on minutes :00 and :30."""
    logger.info('to konw 30 min if can run or not')
    return _is_on_minute_boundary(30)

示例#8
0
class SiteRenderer(ProvisioningRenderer, ScopeOwner):
    """
    Implements the rendering of provisioning intents on the Site. This class is responsible for pushing the proper
    flowMods that will forward packets between the hosts and the ESnet border router. Typically the topology is

         host(s) <-> siteRouter <-> borderRouter

    Only a simple VLAN/port match plus an output-port/VLAN action needs to be set on the siteRouter.

    Terminology: ``lanVlan`` is the VLAN carried on the site router's host-facing ('SiteToHost') ports;
    ``siteVlan`` is the VLAN carried toward the border router.
    """
    debug = False  # when True, eventListener keeps the most recent event in lastEvent
    lastEvent = None
    logger = Logger('SiteRenderer')

    def __init__(self, intent):
        """
        Generic constructor. Translate the intent into one L2SwitchScope on the site router
        and one on the border router, and register both scopes with their controllers.
        :param intent: SiteIntent
        """
        ScopeOwner.__init__(self, name=intent.name)
        self.intent = intent
        self.site = self.intent.site
        self.siteRouter = self.intent.siteRouter
        self.borderRouter = self.intent.borderRouter
        self.macs = {}
        self.active = False
        self.activePorts = {}  # [portname] = TestbedPort
        self.lock = threading.Lock()
        # self.props is presumably created by ScopeOwner.__init__ - confirm.
        self.props['lanVlanIndex'] = {}  # [siteVlan] = lanVlan
        self.props['siteVlanIndex'] = {}  # [lanVlan] = siteVlan
        self.props['portsIndex'] = {}  # [lanVlan] = list of TestbedPort that allows lanVlan to pass
        self.props['portIndex'] = {}  # [str(mac)] = TestbedPort that links to the MAC
        self.props['scopeIndex'] = {}  # [switch.name] = L2SwitchScope
        self.props['borderToSitePort'] = None  # the port linking siteRouter and borderRouter
        self.props['borderToSDNPort'] = None  # the port linking borderRouter and hwSwitch

        # Pre-learn the port of every host MAC up front ("cheating") so that no early
        # packet is missed while waiting to learn the MAC from a PACKET_IN.
        for host in self.site.props['hosts']:
            link = host.props['ports'][1].props['links'][0]
            port = link.props['portIndex'][self.siteRouter.name]
            self.props['portIndex'][str(host.props['mac'])] = port.props['enosPort']

        # Scope for the site router; it owns every port of the site router.
        siteScope = L2SwitchScope(name=intent.name,
                                  switch=self.siteRouter,
                                  owner=self)
        siteScope.props['intent'] = self.intent
        self.props['scopeIndex'][self.siteRouter.name] = siteScope
        for port in self.siteRouter.getPorts():
            self.activePorts[port.name] = port
            port.props['scope'] = siteScope
        siteScope.addEndpoint(self.siteRouter.props['toWanPort'])
        self.siteRouter.props['controller'].addScope(siteScope)

        # Scope for the border router: the port toward the site plus the port stitched
        # toward the hardware (SDN) switch.
        wanScope = L2SwitchScope(name="%s.wan" % intent.name,
                                 switch=self.borderRouter,
                                 owner=self)
        wanScope.props['intent'] = self.intent
        self.props['scopeIndex'][self.borderRouter.name] = wanScope
        borderToSitePort = self.borderRouter.props['sitePortIndex'][
            self.site.name].props['enosPort']
        self.activePorts[borderToSitePort.name] = borderToSitePort
        borderToSitePort.props['scope'] = wanScope
        wanScope.addEndpoint(borderToSitePort)
        borderToSDNPort = self.borderRouter.props['stitchedPortIndex'][
            borderToSitePort.name]
        self.activePorts[borderToSDNPort.name] = borderToSDNPort
        borderToSDNPort.props['scope'] = wanScope
        # borderToSDNPort is deliberately NOT added as an endpoint here
        # (wanScope.addEndpoint(borderToSDNPort)): to support multiple sites the port
        # may be shared, so this site cannot occupy it alone. The scope is added later,
        # during 'vpn addsite'.
        self.props['borderToSitePort'] = borderToSitePort
        self.props['borderToSDNPort'] = borderToSDNPort
        self.borderRouter.props['controller'].addScope(wanScope)

    def checkVlan(self, lanVlan, siteVlan):
        """
        Validate a (lanVlan, siteVlan) pair against the existing mappings (may be invoked from the CLI).

        A VLAN that is already registered must map to exactly the same counterpart as before;
        VLANs that are not registered yet are accepted.
        :return: False (and log a warning) on a conflicting mapping, True otherwise
        """
        siteVlanIndex = self.props['siteVlanIndex']
        lanVlanIndex = self.props['lanVlanIndex']
        lanConflict = lanVlan in siteVlanIndex and siteVlanIndex[lanVlan] != siteVlan
        siteConflict = siteVlan in lanVlanIndex and lanVlanIndex[siteVlan] != lanVlan
        if lanConflict or siteConflict:
            SiteRenderer.logger.warning(
                "different lanVlan and siteVlan is not allowed")
            return False
        return True

    def addVlan(self, lanVlan, siteVlan):
        """
        Register the lanVlan <-> siteVlan pair and stitch siteVlan through the border router
        (may be invoked from the CLI). Conflicting pairs are rejected by checkVlan.
        """
        if not self.checkVlan(lanVlan, siteVlan):
            return
        self.props['lanVlanIndex'][siteVlan] = lanVlan
        self.props['siteVlanIndex'][lanVlan] = siteVlan
        self.stitch(siteVlan)

    def delVlan(self, lanVlan, siteVlan):
        """
        Undo addVlan (may be invoked from the CLI): cut siteVlan at the border router, drop the
        mapping, delete every site-router flowmod that is a broadcast rule or uses either VLAN,
        and remove the VLAN from all site-router endpoints.
        Unknown or conflicting pairs are logged and ignored.
        """
        if not self.checkVlan(lanVlan, siteVlan):
            return
        if siteVlan not in self.props['lanVlanIndex']:
            # Without this guard an unregistered pair reached the pop() below and raised KeyError.
            SiteRenderer.logger.warning("siteVlan %r not found" % (siteVlan,))
            return
        self.cut(siteVlan)
        self.props['lanVlanIndex'].pop(siteVlan)
        self.props['siteVlanIndex'].pop(lanVlan, None)
        siteScope = self.props['scopeIndex'][self.siteRouter.name]
        # Iterate over a snapshot in case delFlowMod() removes entries from flowmodIndex.
        for flowmod in list(siteScope.props['flowmodIndex'].values()):
            # Keep flowmods that are not related to this VLAN pair at all.
            if self._flowmodUsesVlan(flowmod, lanVlan, siteVlan):
                siteScope.delFlowMod(flowmod)
        for port in self.siteRouter.props['ports'].values():
            if port.props['type'] == 'SiteToHost':
                siteScope.delEndpoint(port, lanVlan)
            else:
                siteScope.delEndpoint(port, siteVlan)
        self.props['portsIndex'].pop(lanVlan, None)

    def _flowmodUsesVlan(self, flowmod, lanVlan, siteVlan):
        """
        Return True if flowmod is a broadcast rule, or matches or outputs the VLAN of the pair
        that belongs on that port (lanVlan on 'SiteToHost' ports, siteVlan on all others).
        """
        match = flowmod.match.props
        if match['dl_dst'].isBroadcast():
            return True
        if match['in_port'].props['type'] == 'SiteToHost':
            expectedInVlan = lanVlan
        else:
            expectedInVlan = siteVlan
        if match['vlan'] == expectedInVlan:
            return True
        for action in flowmod.actions:
            if action.props['out_port'].props['type'] == 'SiteToHost':
                expectedOutVlan = lanVlan
            else:
                expectedOutVlan = siteVlan
            if action.props['vlan'] == expectedOutVlan:
                return True
        return False

    def addHost(self, host, lanVlan):
        """
        Allow lanVlan on the site-router port facing host (may be invoked from the CLI).
        The lanVlan must already be registered via addVlan.
        """
        with self.lock:
            if lanVlan not in self.props['siteVlanIndex']:
                SiteRenderer.logger.warning("lanVlan %d not found" % lanVlan)
                return

            toHostPort = self.siteRouter.props['hostPortIndex'][
                host.name].props['enosPort']
            self.props['portsIndex'].setdefault(lanVlan, []).append(toHostPort)

            siteScope = self.props['scopeIndex'][self.siteRouter.name]
            siteScope.addEndpoint(toHostPort, lanVlan)

    def delHost(self, host, lanVlan):
        """
        Undo addHost (may be invoked from the CLI): forget the host port for lanVlan, delete the
        site-router flowmods that match or output lanVlan on that port, and drop the endpoint.
        """
        with self.lock:
            if lanVlan not in self.props['siteVlanIndex']:
                SiteRenderer.logger.warning("lanVlan %d not found" % lanVlan)
                return

            toHostPort = self.siteRouter.props['hostPortIndex'][
                host.name].props['enosPort']
            # .get(): a registered lanVlan may have no hosts yet; indexing raised KeyError.
            hostPorts = self.props['portsIndex'].get(lanVlan, [])
            if toHostPort not in hostPorts:
                SiteRenderer.logger.warning("host %s not exists" % host.name)
                return
            hostPorts.remove(toHostPort)

            siteScope = self.props['scopeIndex'][self.siteRouter.name]
            # Iterate over a snapshot in case delFlowMod() removes entries from flowmodIndex.
            for flowmod in list(siteScope.props['flowmodIndex'].values()):
                match = flowmod.match.props
                found = (match['in_port'].name == toHostPort.name
                         and match['vlan'] == lanVlan)
                if not found:
                    found = any(
                        action.props['out_port'].name == toHostPort.name
                        and action.props['vlan'] == lanVlan
                        for action in flowmod.actions)
                if not found:
                    # Keep flowmods that are not related to the host at all.
                    continue
                # A possible improvement: modify the broadcast flowmod instead of
                # deleting it. Neither the site router nor broadcast matters for the
                # demo, so it is simply deleted.
                siteScope.delFlowMod(flowmod)
            siteScope.delEndpoint(toHostPort, lanVlan)

    def __str__(self):
        return "SiteRenderer(name=%s, activePorts=%r, scopeIndex=%r)" % (
            self.name, self.activePorts, self.props['scopeIndex'])

    def __repr__(self):
        return self.__str__()

    def eventListener(self, event):
        """
        Handle an event from the controller (expected to be a PACKET_IN).

        Records the port of the source MAC the first time it is seen. A broadcast frame is
        multicast to the other host ports of the same lanVlan, plus the WAN port when it came
        from a host. A unicast frame is forwarded to the port recorded for the destination
        MAC; unknown MACs are sent toward the WAN port. In both cases the triggering packet
        is re-sent. Events without a VLAN, or whose VLAN has no registered pair, are dropped.
        :param event: ScopeEvent
        """
        if event.__class__ != PacketInEvent:
            SiteRenderer.logger.warning("%s is not a PACKET_IN." % event)
        with self.lock:
            if 'vlan' not in event.props:
                # no VLAN, reject
                SiteRenderer.logger.debug("no VLAN, reject %r" % event)
                return
            SiteRenderer.logger.debug("eventListener: %r" % event)
            if SiteRenderer.debug:
                SiteRenderer.lastEvent = event

            inPort = event.props['in_port'].props['enosPort']
            srcMac = event.props['dl_src']
            dstMac = event.props['dl_dst']
            inVlan = event.props['vlan']
            inType = inPort.props['type']
            # Derive the other VLAN of the pair from the side the packet arrived on.
            if inType == 'SiteToHost':
                lanVlan = inVlan
                siteVlan = self.props['siteVlanIndex'].get(lanVlan)
            elif inType == 'SiteToCore':
                siteVlan = inVlan
                lanVlan = self.props['lanVlanIndex'].get(siteVlan)
            else:
                SiteRenderer.logger.warning("Unknown event %r in %r" %
                                            (event, self))
                return
            if lanVlan is None or siteVlan is None:
                # Drop packets on VLANs never registered via addVlan instead of
                # raising KeyError inside the controller callback.
                SiteRenderer.logger.warning(
                    "VLAN %r is not configured, drop %r in %r" % (inVlan, event, self))
                return
            # Record which port srcMac is behind (first sighting wins).
            if str(srcMac) not in self.props['portIndex']:
                self.props['portIndex'][str(srcMac)] = inPort

            etherType = event.props['ethertype']
            payload = event.props['payload']
            switch = inPort.props['node'].props['enosNode']
            scope = inPort.props['scope']
            if dstMac.isBroadcast():
                outputs = []
                if inType == 'SiteToHost':
                    outputs.append(
                        (dstMac, siteVlan, switch.props['toWanPort']))
                # .get(): a registered lanVlan may have no host ports yet.
                for outPort in self.props['portsIndex'].get(lanVlan, []):
                    if outPort.name != inPort.name:
                        outputs.append((dstMac, lanVlan, outPort))
                scope.multicast(switch, dstMac, inVlan, inPort, outputs)
                for (outMac, outVlan, outPort) in outputs:
                    scope.send(switch, outPort, srcMac, outMac, etherType,
                               outVlan, payload)
            else:
                if str(dstMac) not in self.props['portIndex']:
                    SiteRenderer.logger.warning(
                        "Unknown destination (%r) on site %r" % (event, self))
                    # HACK: treat every unknown MAC as reachable through the WAN port.
                    self.props['portIndex'][str(dstMac)] = switch.props['toWanPort']
                outPort = self.props['portIndex'][str(dstMac)]
                if outPort.props['type'] == 'SiteToCore':
                    outVlan = siteVlan
                else:
                    outVlan = lanVlan
                scope.forward(switch, dstMac, inVlan, inPort, dstMac, outVlan,
                              outPort)
                scope.send(switch, outPort, srcMac, dstMac, etherType, outVlan,
                           payload)

    def stitch(self, siteVlan):
        """
        Open siteVlan from the site router's WAN port through the border router.

        Adds siteVlan endpoints on the site router's WAN-facing port and on both border-router
        ports, then requests forwarding borderToSitePort <-> borderToSDNPort in both directions.
        :return: True if both forward() calls succeeded
        """
        siteScope = self.props['scopeIndex'][self.siteRouter.name]
        siteRouterToWanPort = self.siteRouter.props['toWanPort'].props[
            'enosPort']
        siteScope.addEndpoint(siteRouterToWanPort, siteVlan)

        inPort = self.props['borderToSitePort']
        outPort = self.props['borderToSDNPort']
        wanScope = self.props['scopeIndex'][self.borderRouter.name]
        wanScope.addEndpoint(inPort, siteVlan)
        wanScope.addEndpoint(outPort, siteVlan)
        success = True
        # site -> hw switch, then hw switch -> site
        for port1, port2 in ((inPort, outPort), (outPort, inPort)):
            if not wanScope.forward(self.borderRouter, None, siteVlan, port1,
                                    None, siteVlan, port2):
                success = False
        return success

    def cut(self, siteVlan):
        """
        Reverse stitch(): stop both border-router forwarding directions for siteVlan, then
        remove the siteVlan endpoints on the site router's WAN port and both border-router ports.
        :return: True if both stopForward() calls succeeded
        """
        siteScope = self.props['scopeIndex'][self.siteRouter.name]
        siteRouterToWanPort = self.siteRouter.props['toWanPort'].props[
            'enosPort']

        wanScope = self.props['scopeIndex'][self.borderRouter.name]

        inPort = self.props['borderToSitePort']
        outPort = self.props['borderToSDNPort']
        success = True
        # site -> hw switch, then hw switch -> site
        for port1, port2 in ((inPort, outPort), (outPort, inPort)):
            if not wanScope.stopForward(self.borderRouter, None, siteVlan,
                                        port1, None, siteVlan, port2):
                success = False
        siteScope.delEndpoint(siteRouterToWanPort, siteVlan)
        wanScope.delEndpoint(inPort, siteVlan)
        wanScope.delEndpoint(outPort, siteVlan)
        return success

    def setBorderRouter(self):
        """Stitch every registered siteVlan; return True only if all stitches succeeded."""
        success = True
        for siteVlan in self.props['siteVlanIndex'].values():
            if not self.stitch(siteVlan):
                success = False
        return success

    def removeFlowEntries(self):
        """Not implemented: always returns False."""
        return False

    def execute(self):
        """
        Renders the intent.

        Note: this is of limited use so far, since the forwarding rules are configured
        dynamically at runtime (see eventListener). It marks the renderer active and stitches
        the border router for every already-registered VLAN.
        :return: True if every stitch succeeded (trivially True when no VLAN is registered)
        """
        self.active = True
        return self.setBorderRouter()

    def destroy(self):
        """
        Destroys or stops the rendering of the intent.
        :return: the result of removeFlowEntries(), which currently is always False
        """
        self.active = False
        return self.removeFlowEntries()
示例#9
0
from common.instance import celery
from common.time_helper import RUN_TIME_SCHEDULE
from common.utils import Logger
from missions.interface import run_mission, get_valid_mission_missionary

log = Logger('MY APS')


@celery.on_after_configure.connect
def setup_periodic_tasks(sender, **kwargs):
    """Register one Celery periodic task per valid mission.

    Connected to Celery's ``on_after_configure`` signal. Each mission runs
    ``run_mission(mission.id)`` on the schedule that ``RUN_TIME_SCHEDULE``
    maps its missionary's ``run_time`` to, under the entry name
    ``mission-<id>``. The registered mission ids are logged at the end.
    """
    mission_ids = []
    for mission_dict in get_valid_mission_missionary():
        mission = mission_dict['mission']
        missionary = mission_dict['missionary']
        run_time = RUN_TIME_SCHEDULE[missionary.run_time]
        sender.add_periodic_task(run_time, run_mission.s(mission.id),
                                 name='mission-{}'.format(mission.id))
        mission_ids.append(mission.id)
    # The message has a '{}' placeholder that was previously never filled in.
    log.info('after add all missionary, the mission id to entry: {}'.format(mission_ids))
示例#10
0
import inspect
from common.email_helper import send_error
from common.utils import Logger

# Module-level logger for the task template module.
logger = Logger('task template')


class TaskTemplateBase:
    RULE = ''
    MSG_FORMAT = ''

    def __init__(self, *args, **kwargs):
        """Initialise per-run state; any positional/keyword arguments are ignored.

        ``msg_args`` starts empty. The ``msg`` property treats an empty
        ``msg_args`` as "task not run yet".
        """
        self.need_record = True  # presumably read by callers to decide whether to record the result
        self.msg_args = ()

    def __str__(self):
        """Render the task as its human-readable ``msg``."""
        text = self.msg
        return text

    def try_pass(self, *args, **kwargs) -> bool:
        """Evaluate the task rule and return True if it passes.

        Subclasses must override this; ``get_args`` reports the override's
        parameter names.

        :raises NotImplementedError: always, in the base class. It is a
            subclass of ``Exception``, so existing ``except Exception``
            handlers still catch it.
        """
        raise NotImplementedError('must be covered')

    @classmethod
    def get_args(cls):
        """Return the positional parameter names of ``try_pass``, minus ``self``, as ``str(list)``.

        For example, an override ``try_pass(self, a, b)`` gives ``"['a', 'b']"``.
        """
        spec = inspect.getfullargspec(cls.try_pass)
        names = spec.args[1:]  # drop the leading ``self``
        return str(names)

    @property
    def msg(self):
        if len(self.msg_args) == 0:
            result = '[Task Temp未运行: {}]'.format(self.RULE)
        else:
def train(func_train_one_batch, param_dict, path, log_dir_path):
    """Train the embedding model together with its generators and discriminator.

    Runs ``p.num_epochs`` epochs over the training split. After each epoch it
    evaluates NMI/F1 on the test split and saves all four models to
    ``p.model_save_path`` whenever test NMI improves. At the end it writes
    per-epoch loss curves as PNGs into ``log_dir_path``.

    :param func_train_one_batch: callable performing one optimisation step;
        it returns ``(loss_pos_gen, loss_neg_gen, loss_dis)``
    :param param_dict: keyword arguments forwarded to ``Logger``; the result
        ``p`` is read for the dataset, model and training settings
    :param path: dataset root directory
    :param log_dir_path: directory given to ``Logger`` and used for the plots
    """
    dis_loss = []
    pos_gen_loss = []
    neg_gen_loss = []

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    p = Logger(log_dir_path, **param_dict)

    # Load the dataset. Use ``==``: ``is`` tests object identity, which is not
    # guaranteed for equal strings, so 'car196' could silently fall through.
    if p.dataset == 'car196':
        data = Car196(root=path)
    else:
        print('DATASET is', p.dataset)
        data = CUB_200_2011(root=path)

    # Presumably yields batches of n_samples examples from each of n_classes classes - confirm.
    sampler = BalancedBatchSampler(data.train.label_to_indices,
                                   n_samples=p.n_samples,
                                   n_classes=p.n_classes)
    kwargs = {'num_workers': 6, 'pin_memory': True}

    train_loader = DataLoader(data.train, batch_sampler=sampler,
                              **kwargs)  # (5 * 98, 3, 224, 224)
    test_loader = DataLoader(data.test, batch_size=p.batch_size)

    # construct the models
    model = ModifiedGoogLeNet(p.out_dim, p.normalize_hidden).to(device)
    model_pos_gen = Generator(p.out_dim, p.normalize_hidden).to(device)
    model_neg_gen = Generator(p.out_dim, p.normalize_output).to(device)

    model_dis = Discriminator(p.out_dim, p.out_dim).to(device)

    model_optimizer = optim.Adam(model.parameters(), lr=p.learning_rate)
    pos_gen_optimizer = optim.Adam(model_pos_gen.parameters(),
                                   lr=p.learning_rate)
    neg_gen_optimizer = optim.Adam(model_neg_gen.parameters(),
                                   lr=p.learning_rate)
    dis_optimizer = optim.Adam(model_dis.parameters(), lr=p.learning_rate)
    # NOTE(review): a second Adam optimizer over the same embedding parameters;
    # presumably func_train_one_batch uses it for a separate feature loss - confirm.
    model_feat_optimizer = optim.Adam(model.parameters(), lr=p.learning_rate)

    time_origin = time.time()
    best_nmi_1 = 0.  # best NMI so far, with the F1 of that epoch
    best_f1_1 = 0.
    best_nmi_2 = 0.  # NMI of the epoch with the best F1 so far
    best_f1_2 = 0.

    for epoch in range(p.num_epochs):
        time_begin = time.time()
        epoch_loss_neg_gen = 0
        epoch_loss_pos_gen = 0
        epoch_loss_dis = 0
        total = 0
        for batch in tqdm(train_loader, desc='# {}'.format(epoch)):
            triplet_batch = generate_random_triplets_from_batch(
                batch, n_samples=p.n_samples, n_class=p.n_classes)
            loss_pos_gen, loss_neg_gen, loss_dis = func_train_one_batch(
                device, model, model_pos_gen, model_neg_gen, model_dis,
                model_optimizer, model_feat_optimizer, pos_gen_optimizer,
                neg_gen_optimizer, dis_optimizer, p, triplet_batch, epoch)

            epoch_loss_neg_gen += loss_neg_gen
            epoch_loss_pos_gen += loss_pos_gen
            epoch_loss_dis += loss_dis
            # Presumably the number of triplets in the batch (first dim of element 0).
            total += triplet_batch[0].size(0)

        loss_average_neg_gen = epoch_loss_neg_gen / total
        loss_average_pos_gen = epoch_loss_pos_gen / total
        loss_average_dis = epoch_loss_dis / total

        dis_loss.append(loss_average_dis)
        pos_gen_loss.append(loss_average_pos_gen)
        neg_gen_loss.append(loss_average_neg_gen)

        nmi, f1 = evaluate(device,
                           model,
                           model_dis,
                           test_loader,
                           epoch,
                           n_classes=p.n_classes,
                           distance=p.distance_type,
                           normalize=p.normalize_output,
                           neg_gen_epoch=p.neg_gen_epoch)
        if nmi > best_nmi_1:
            best_nmi_1 = nmi
            best_f1_1 = f1
            # Checkpoint every model whenever test NMI improves.
            for checkpoint_model, filename in ((model, "model.pt"),
                                               (model_pos_gen, "model_pos_gen.pt"),
                                               (model_neg_gen, "model_neg_gen.pt"),
                                               (model_dis, "model_dis.pt")):
                torch.save(checkpoint_model,
                           os.path.join(p.model_save_path, filename))
        if f1 > best_f1_2:
            best_nmi_2 = nmi
            best_f1_2 = f1

        time_end = time.time()
        epoch_time = time_end - time_begin
        total_time = time_end - time_origin

        print("#", epoch)
        print("time: {} ({})".format(epoch_time, total_time))
        print("[train] loss NEG gen:", loss_average_neg_gen)
        print("[train] loss POS gen:", loss_average_pos_gen)
        print("[train] loss dis:", loss_average_dis)
        print("[test]  nmi:", nmi)
        print("[test]  f1:", f1)
        print("[test]  nmi:", best_nmi_1, "  f1:", best_f1_1, "for max nmi")
        print("[test]  nmi:", best_nmi_2, "  f1:", best_f1_2, "for max f1")
        print(p)

    # One PNG per loss curve, named after the curve.
    for values, name in ((dis_loss, "dis_loss"),
                         (pos_gen_loss, "pos_gen_loss"),
                         (neg_gen_loss, "neg_gen_loss")):
        plt.plot(values)
        plt.ylabel(name)
        plt.savefig(os.path.join(log_dir_path, name + '.png'))
        plt.close()
示例#12
0
from celery.app.task import Task
from huobi.model.constant import OrderType

from common.instance import celery
from common.utils import Logger
from trade.huobi_client import get_order, market_buy, market_sell, get_all_spot_balance
from trade.model import Trade


# Module-level logger for the trade interface module.
logger = Logger('trade interface')


@celery.task(bind=True, name='记录订单', max_retries=20, default_retry_delay=30)
def record_by_order_id(self: Task, order_id: int, mission_id: int, mission_name: str, missionary_id: int):
    try:
        if order_id == 0:           # for test
            return
        order = get_order('thetausdt', order_id)
        price = order.filled_cash_amount / order.filled_amount
        if price == 0 or price is None:
            raise ValueError('market buy, the price should not be 0 or None')
        # if order.order_type == OrderType.SELL_MARKET:            # 卖出
        #     price = order.filled_cash_amount / order.filled_amount
        # elif order.order_type == OrderType.BUY_MARKET:           # 买入
        #     if order.price == 0 or order.price is None:
        #         raise ValueError('market buy, the price should not be 0 or None')
        #     price = order.price
        # else:
        #     raise TypeError('unexpect order type')

        return Trade.create(amount=order.amount, symbol=order.symbol,
示例#13
0
    def run(self):
        """Train the actor and critic networks on ``self.env``.

        Builds a fresh TensorFlow graph and session (stored on ``self.sess``)
        and warms up the observation scaler with 5 episodes. It then repeats
        until ``self.num_episodes`` episodes have been collected: gather a
        batch of ``self.batch_size`` episodes, compute values, discounted
        returns and advantages, update both networks, and write the log row.
        """
        # +1 observation feature: the time step appended by run_episode().
        observation_dim = self.env.observation_space.shape[0] + 1
        action_dim = self.env.action_space.shape[0]
        timestamp = datetime.utcnow().strftime(
            "%b-%d_%H:%M:%S")  # unique directory name for this run
        logger = Logger(logname=self.env_name, now=timestamp)
        scaler = Scaler(observation_dim)

        # Fresh graph and single-threaded session that grows GPU memory on demand.
        tf.reset_default_graph()
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        intra_op_parallelism_threads=1,
                                        inter_op_parallelism_threads=1)
        session_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=session_config)

        actor = ActorNetwork(self.sess, observation_dim, action_dim,
                             self.kl_targ, self.hid1_mult,
                             self.policy_logvar, self.clipping_range)
        critic = CriticNetwork(self.sess, observation_dim, self.hid1_mult)
        actor.build_graph()
        critic.build_graph()
        self.init = tf.group(tf.global_variables_initializer(),
                             tf.local_variables_initializer())
        self.sess.run(self.init)

        # A few episodes of the untrained policy initialise the scaler.
        self.run_policy(actor, scaler, logger, episodes=5)
        episodes_done = 0
        while episodes_done < self.num_episodes:
            trajectories = self.run_policy(actor, scaler, logger,
                                           episodes=self.batch_size)
            episodes_done += len(trajectories)
            self.add_value(trajectories, critic)  # estimated state values
            self.add_disc_sum_rew(trajectories)  # discounted reward sums
            self.add_gae(trajectories)  # advantages

            # Flatten every per-episode array into one batch-wide array.
            batch = {
                key: np.concatenate([traj[key] for traj in trajectories])
                for key in ('observes', 'actions', 'disc_sum_rew', 'advantages')
            }
            raw_advantages = batch['advantages']
            advantages = (raw_advantages - raw_advantages.mean()) / (
                raw_advantages.std() + 1e-6)

            self.log_batch_stats(batch['observes'], batch['actions'],
                                 advantages, batch['disc_sum_rew'], logger,
                                 episodes_done)
            actor.backward(batch['observes'], batch['actions'], advantages,
                           logger)
            critic.backward(batch['observes'], batch['disc_sum_rew'], logger)
            logger.write(display=True)  # flush this iteration's row to file and stdout
        logger.close()
        actor.close_sess()
        critic.close_sess()
示例#14
0
from datetime import datetime, timedelta

from common.utils import Logger
from conditions.interface import get_or_create_condition, get_condition_by_name
from datas.interface import get_present_period, get_LB, get_vol_vol3, get_UB
from tasks.template.trade import TradeTask
from trade.interface import Trade
from trade.interface import market_buy, market_sell

# Module-level loggers for the market-buy and market-sell task paths.
# NOTE(review): neither logger is referenced by ShortFirstStep below.
logger_buy = Logger('market buy task')
logger_sell = Logger('market sell task')


class ShortFirstStep(TradeTask):
    """Trade-task check: price below the lower Bollinger band plus a volume rise.

    ``RULE`` (a runtime string, kept in Chinese) reads roughly: "current price
    should be below the lower Bollinger band && last period's trade volume is
    greater than the previous three periods' volume && market buy".
    """
    # Human-readable rule description (runtime text; left untranslated).
    RULE = '现价应该低于布林下轨 && 上一周期交易量大于前三周期交易量 && 市场买入'
    # Result-message template. Its 7 ``{}`` placeholders are filled from
    # ``self.msg_args`` in order: current price, lower band, last-period total
    # volume, previous-three-period mean volume, buy ratio, currency pair,
    # order id.
    MSG_FORMAT = '现价 {} < 布林下轨: {} && 上一周期的交易总量{} > 上三周期的交易均值 {} && 买入比例为: {},货币对为: {}, 订单号为: {}'

    def try_pass(self, symbol, period, amount: int) -> bool:
        """Check the entry conditions for ``symbol`` on ``period``.

        Returns False, after filling ``self.msg_args`` for ``MSG_FORMAT``, when
        the latest close is not below the lower band, or when the last
        period's volume does not exceed the three-period figure.

        NOTE(review): when both checks pass, this method falls off the end and
        returns None. It never uses ``amount`` and never sets ``order_id``, so
        the market-buy step described by ``RULE`` appears to be missing
        (truncated?). Confirm the intended success path.
        """
        now = get_present_period(symbol=symbol).close
        lb = get_LB(symbol=symbol, period=period)

        # Condition 1: the current close must be strictly below the lower band.
        if now >= lb:
            self.msg_args = (now, lb, None, None, None, None, None)
            return False
        ###############################################
        # Condition 2: last period's volume must exceed the 3-period figure.
        vol, vol3 = get_vol_vol3(symbol=symbol, period=period)

        # self.msg_args = (vol, vol3)
        if vol <= vol3:
            self.msg_args = (now, lb, vol, vol3, None, None, None)
            return False
示例#15
0
class TaskEngine:
    """Runtime wrapper around one persisted ``Task`` row.

    Copies the row's configuration onto the engine, instantiates the task's
    template from ``task_template_dict``, evaluates its ``can_run`` condition,
    runs the template's ``try_pass`` check and records the outcome.
    """
    # NOTE(review): class attribute shared by every TaskEngine. ``init_task``
    # overwrites ``logger.clue`` with the latest task name, so interleaved
    # engines log under whichever name was assigned last.
    logger = Logger('Task engine')

    def __init__(self, task_id=None):
        """Load task ``task_id`` from the database; ``None`` skips loading (tests)."""
        # Declare every attribute up front. An engine built without a task id
        # is then still printable/inspectable instead of raising AttributeError.
        self.task_id = None
        self.task_name = None
        self.kwargs = None
        self.run_time = None
        self.can_run_str = None
        self.task = None

        # Mission context, filled in later by ``get_info_from_mission``.
        self.task_index = None
        self.mission_id = None
        self.mission_name = None
        self.missionary_id = None

        if task_id is None:  # for test
            return

        task = Task.query.filter_by(id=task_id).first()
        # NOTE(review): ``first()`` yields None for an unknown id, which makes
        # ``init_task`` fail with AttributeError on ``task.id``.
        self.init_task(task)

    def __str__(self):
        return '[{}]'.format(self.task_name)

    def init_task(self, task):
        """Copy ``task``'s fields onto the engine and instantiate its template."""
        self.task_id = task.id
        self.task_name = task.name
        self.kwargs = task.kwargs
        self.run_time = task.run_time
        self.can_run_str = task.can_run
        self.task = task_template_dict[task.template_name]()
        self.logger.clue = self.task_name

    def get_info_from_mission(self, index: int = 0, mission_engine=None):
        """Record which mission/missionary this task is running under.

        :param index: position of the task in the mission's task line
            (0 is used for the target task).
        :param mission_engine: the engine driving the mission. Despite the
            ``None`` default it is required: its ``mission_id``, ``name`` and
            ``missionary.id`` are read here.
        """
        # 0 is target
        self.task_index = index
        self.mission_id = mission_engine.mission_id
        self.mission_name = mission_engine.name
        self.missionary_id = mission_engine.missionary.id

    def can_run(self) -> bool:
        """Evaluate the task's ``can_run`` condition (resolved once, then cached)."""
        if not hasattr(self, '_can_run'):
            self._can_run = get_condition_by_name(self.can_run_str)
        return self._can_run()

    def try_pass(self) -> bool:
        """Run the template's ``try_pass`` check and record the result.

        For trade tasks that pass, the resulting order is recorded
        asynchronously through ``record_by_order_id.delay``.

        :return: the value returned by ``create_result_log``.

        NOTE(review): annotated ``bool``, but this returns the ``Result`` object
        created by ``create_result_log``, and callers test it for truthiness.
        Confirm that ``Result``'s truthiness reflects ``is_pass``.
        """
        self.logger.info('try to pass task')
        is_pass = self.task.try_pass(**self.kwargs)
        self.logger.info('task pass completed, result: {}'.format(is_pass))

        result = self.create_result_log(is_pass)

        if is_pass and issubclass(self.task.__class__, TradeTask):  # record the trade result
            order_id = self.task.order_id
            record_by_order_id.delay(order_id, self.mission_id,
                                     self.mission_name, self.missionary_id)
            # Logged once (this line was previously duplicated).
            self.logger.info(
                'task is type of trade, save trade record has in async.order id: {}'
                .format(order_id))
        return result

    def create_result_log(self, is_pass):
        """Persist a ``Result`` for this run and, if the task asks for it, a ``ResultLog``."""
        result = Result.create_by_task_engine(self, is_pass)
        self.logger.info(
            'create result by task engine successfully, result: {}'.format(
                result))
        if self.task.need_record:
            result_log = ResultLog.save_from_task_engine(self)
            self.logger.info(
                'create result log by task engine successfully, result_log id: {}'
                .format(result_log.id))
            self.logger.info('task type: {}'.format(type(self.task)))

        return result
示例#16
0
from common.instance import celery
from common.utils import Logger
from missions.model import Missionary, Mission
from missions.engine import MissionaryEngine
# from schedule.interface import drop_all, mission_id_to_entry, add_missionary

# Module logger for the mission interface helpers below.
log = Logger('mission interface')


def get_valid_mission_id_line():
    """Return the valid mission id line, delegating to ``Mission``."""
    id_line = Mission.get_valid_mission_id_line()
    return id_line


def get_valid_missions():
    """Return the valid missions, delegating to ``Mission``."""
    missions = Mission.get_valid_missions()
    return missions


def get_valid_mission_missionary():
    """Pair every valid mission with its existing missionary.

    :return: a list of ``{'mission': mission, 'missionary': missionary}`` dicts,
        one per valid mission that already had a missionary.

    Missions with no missionary get one created via
    ``Missionary.create_by_mission``.
    NOTE(review): those freshly created missionaries are not included in the
    returned list. Confirm whether they should be.
    """
    result = []
    mission_id_to_missionary = Missionary.get_valid_mission_id_to_missionary()
    # Materialize once. The log line and the loop then see the same missions,
    # and an iterator or query is not consumed/executed twice.
    mission_line = tuple(Mission.get_valid_missions())
    log.info('all valid mission: {}'.format(mission_line))
    for mission in mission_line:
        missionary = mission_id_to_missionary.get(mission.id)
        if missionary:
            if mission.is_valid:
                result.append({'mission': mission, 'missionary': missionary})
        else:
            log.info('mission({})\'s missionary not found, create new'.format(
                mission))
            Missionary.create_by_mission(mission)
    return result
示例#17
0
from sqlalchemy import desc

from common.instance import db
from common.email_helper import send_error
from common.time_helper import now_format_time, now_int_timestamp
from common.utils import Logger


# Module logger; safe_commit uses it to record commit failures.
log = Logger('BASE MODEL')


def safe_commit():
    """Commit the current ``db.session``; on failure, roll back and report.

    :return: True when the commit succeeded. False when it failed; in that case
        the error has been logged, the session rolled back and the error
        reported via ``send_error``. Callers that ignore the return value are
        unaffected.
    """
    try:
        db.session.commit()
    except Exception as e:
        log.error("DB sql commit error, detail: {}, {}".format(e, e.args))
        # Roll back before the (network-bound) error report. If sending the
        # report fails, the session must not be left in its failed state.
        db.session.rollback()
        send_error(e, 'db commit')
        return False
    return True


def save_many_models(*args):
    """Stage every model in ``args`` on the session, then commit them together."""
    db.session.add_all(args)
    safe_commit()


class UpdateModelBase(db.Model):
    """Abstract declarative base that adds an integer id and two time columns.

    ``__abstract__ = True`` tells SQLAlchemy not to map this class to a table
    of its own; concrete subclasses inherit the columns declared here.
    """
    __abstract__ = True
    # Integer primary key.
    id = db.Column(db.Integer, primary_key=True)
    # Formatted time string. ``now_format_time`` is passed uncalled, so it is
    # evaluated when a row is inserted.
    # NOTE(review): the class name suggests "update time", but a column default
    # only fires on INSERT. Confirm whether an ``onupdate`` was intended.
    time = db.Column(db.VARCHAR(64), default=now_format_time)
    # Integer timestamp, filled the same way by ``now_int_timestamp``.
    timestamp = db.Column(db.Integer, default=now_int_timestamp)
示例#18
0
class MissionaryEngine:
    """Drive one ``Mission`` forward using its ``Missionary`` progress record.

    Runs the mission's task line from the missionary's ``next_task_index``.
    When every task has passed, it tries the mission's target task; success
    finishes both the missionary and the mission.
    """
    # NOTE(review): class attribute shared by every engine. ``__init__``
    # overwrites ``log.clue`` with the latest mission name.
    log = Logger('Missionary Engine')

    def __init__(self, mission: Mission, missionary: Missionary):
        """Capture the mission's configuration and its missionary record."""
        self.mission_id = mission.id
        self.name = mission.name
        self.target_task_id = mission.target
        self.task_id_line = mission.task_line
        self.can_run_before_each_task = get_condition_by_name(
            mission.can_run_before_each_task)
        self.can_run_before_target = get_condition_by_name(
            mission.can_run_before_target)
        self.next_run_mission_id = mission.next_run_mission
        self.missionary = missionary
        self.log.clue = self.name

    @property
    def target(self):
        """The mission's target task, loaded lazily on first access."""
        if not hasattr(self, '_target'):
            self._target = get_task_by_id(self.target_task_id)
            self.log.info('get target: {}'.format(self._target))
        return self._target

    def try_pass(self):
        """Run the pending tasks, then the target task.

        :return: True only when the target passes (the missionary and the
            mission are then marked finished); False in every other case.
        """
        self.log.info('try to pass')
        if not self.can_run_before_each_task():
            self.log.info('cant run any task due to can_run_before_each_task')
            return False
        if not self.find_next_task_and_run_each():
            return False

        if not self.can_run_before_target():
            self.log.info('cant run target due to can_run_before_target')
            return False
        if not self.try_pass_target():
            # Explicit False: this path used to fall through and return None.
            return False

        self.missionary.finish()
        self.log.info('update missionary model successfully')
        self.mission_finish()
        return True

    def find_next_task_and_run_each(self):
        """Run the task line from the missionary's next index; stop at the first failure.

        :return: True when every remaining task passed, False otherwise.
        """
        for index, task_id in enumerate(self.task_id_line):
            # Tasks before next_task_index already passed in earlier runs.
            if index < self.missionary.next_task_index:
                continue

            task = self.get_task_by_id(task_id, index)
            self.log.info(
                'run from index: {}, task_id {}, get task: {} '.format(
                    index, task_id, task))
            if not self.run_task(index, task):
                return False
            self.missionary_follow_task(self.missionary, task)

        self.log.info('pass all task successfully')
        return True

    def run_task(self, index, task) -> bool:
        """Try to pass ``task`` (at position ``index``) within this missionary.

        :return: True, after advancing the missionary's task index without
            saving, when the task may run at the missionary's run time, its
            ``can_run`` condition holds and it passes. False otherwise.
        """
        if self.is_task_can_run_in_this_missionary(task, self.missionary):
            can_run_result = task.can_run()
            self.log.info(f'try to pass task {index}: {task}')
            if can_run_result and task.try_pass():
                self.log.info('pass one task successfully')
                self.missionary.add_task_index(save=False)
                return True
        return False

    def try_pass_target(self):
        """Bind the target task to this mission and try to pass it; notify on success."""
        self.target.get_info_from_mission(mission_engine=self)
        if self.target.can_run() and self.target.try_pass():
            self.log.info('pass target successfully')
            send_missionary_pass(self.name)
            return True
        return False

    def mission_finish(self):
        """Mark the finished mission as no longer valid."""
        self.log.info('after all pass, ready to start next mission')
        finished_mission = Mission.query.filter(
            Mission.id == self.mission_id).first()
        finished_mission.is_valid = False
        finished_mission.save()
        # mission = Mission.query.filter_by(id=self.next_run_mission_id).first()
        # mission.is_valid = True
        # save_many_models(finished_mission, mission)
        # self.log.info('finished mission: {}, next mission: {}'.format(finished_mission, mission))

        # self.log.info('try to del finished mission')
        # del_missionary(finished_mission.id)
        # missionary = Missionary.get_or_create_by_mission(mission)
        # self.log.info('try to add next mission')
        # add_missionary(mission_id=mission.id, run_time=missionary.run_time)

    @classmethod
    def is_task_can_run_in_this_missionary(cls, task, missionary):
        """Return True if ``task.run_time`` equals, or can run at, ``missionary.run_time``."""
        if missionary.run_time == task.run_time:
            cls.log.info('missionary run time is same as task run time')
            return True

        if can_run_at_now(task.run_time, now=missionary.run_time):
            cls.log.info(
                f'missionary run time not same as task, but {task.run_time} can run at {missionary.run_time}'
            )
            return True

        cls.log.info('task({}) cant run at {}'.format(task.run_time,
                                                      missionary.run_time))
        return False

    @classmethod
    def missionary_follow_task(cls, missionary, task):
        """Align (and save) the missionary's run time with ``task``'s when they differ."""
        if missionary.run_time != task.run_time:
            missionary.run_time = task.run_time  # if not same,follow to task
            missionary.save()
            cls.log.info('missionary run time update to {}'.format(
                missionary.run_time))
            return
        cls.log.info(f'missionary run time is same as task {task}')

    def get_task_by_id(self, task_id, index):
        """Load task ``task_id`` and bind it to this engine's mission context.

        This must be an instance method, not a classmethod: the task reads
        ``mission_id``, ``name`` and ``missionary`` from the engine passed here,
        and those exist only on instances.
        """
        # Inside a method, this name resolves to the module-level loader,
        # not to this method.
        task = get_task_by_id(task_id)
        task.get_info_from_mission(index, self)
        return task

    #
    # @property
    # def next_task_index(self):
    #     redis.set()
    # def show_info(self):
    #     result = []
    #     for task_id in self.task_id_line:
    #         result.append(task().pass_(True))
    #     return result


#     def get_last_finished_task_id(self) -> int:
#         # 如果一条都找不到,或找到的已经is end=True,创建一条-1,否则返回正常值
#         start = time.time()
#         self.log.info('try to get_last_finished_task_id')
#         result = TaskPass.get_last(mission=self.NAME)
#         if result is None or result.is_end:
#             self.log.info('first run, result: {}'.format(result))
#             self.create_first_record()
#             return -1
#
#         self.start_time = result.mission_start
#         self.log.info('get_last_finished_task_id cost {}'.format(time.time() - start))
#         return result.task_id
#
#     def create_first_record(self):
#         self.start_time = now_format_time()
#         tp = TaskPass(mission_start=self.start_time, mission=self.NAME, task_id=-1)
#         tp.save()
#         self.log.info('create first recode, task pass: {}'.format(tp))
#
#     def save_pass_record(self, task):
#         if getattr(self, 'start_time', False):
#             mp = TaskPass(task=task.__class__.__name__,
#                           mission_start=self.start_time,
#                           mission=self.NAME,
#                           task_id=self.TASKLINE.index(task))
#             if mp.task_id == len(self.TASKLINE) - 1:
#                 mp.is_end = True
#             mp.save()
#         else:
#             self.create_first_record()
#
#     @classmethod
#     def run(cls):
#         if cls.can_run:
#             m = cls()
#             m.do_run()
#             return m
#         else:
#             raise Exception('cant run')
#
#     @property
#     def can_run(self):
#         return not Conditions.query.filter_by(valid=True).exists()
#
#     def do_run(self):
#         raise Exception('must be covered')
#
#
# class MissionForBuy(MissionBase):
#     Trade = None
#
#     def __init__(self, *args, **kwargs):
#         self.check_pre_init()
#         super().__init__(*args, **kwargs)
#
#     def check_pre_init(self):
#         if self.Trade is None:
#             raise Exception('Trade should not be None')
#
#     def do_run(self):
#         self.log.info('mission.start')
#         start = time.time()
#         if self.pass_():
#             order = self.Trade(self.NAME).buy()
#             trade = Trade.create_by_order(order, self.NAME)
#             send_trade('成功买入!', trade)
#         self.log.info('MissionForBuy.run cost {}'.format(time.time() - start))
#
#
# class MissionForSell(MissionBase):
#     Trade = None
#
#     def __init__(self, *args, **kwargs):
#         self.check_pre_init()
#         super().__init__(*args, **kwargs)
#
#     def check_pre_init(self):
#         if self.Trade is None:
#             raise Exception('Trade should not be None')
#
#     def do_run(self):
#         if self.pass_():
#             order = self.Trade(self.NAME).sell()
#             trade = Trade.create_by_order(order, self.NAME)
#             send_trade('成功卖出!', trade)
示例#19
0
from huobi.requstclient import RequestClient
from huobi.model.constant import OrderType, AccountType
from huobi.model import OrderState, Account

from common.utils import one_more_try, Logger
from configs import AccessKey, SecretKey, HUOBI_URL, TRADE_MATCH, TESTING, NOTTRADE

huobi_logger = Logger('huobi')


def _mask_secret(value, visible=4):
    """Return ``value`` as text with all but its last ``visible`` characters masked."""
    text = str(value or '')
    if len(text) <= visible:
        return '*' * len(text)
    return '*' * (len(text) - visible) + text[-visible:]


# Never write API credentials to the log in clear text: log files are copied
# and shipped around far more freely than config. Only a masked fingerprint
# of the access key is logged, and the secret key is reported as set/not set.
huobi_logger.info(f'AccessKey: {_mask_secret(AccessKey)}')
huobi_logger.info(f'SecretKey: {"<set>" if SecretKey else "<not set>"}')


def get_request_client():
    """Build a fresh Huobi ``RequestClient`` from the configured credentials and URL."""
    credentials = dict(api_key=AccessKey, secret_key=SecretKey)
    return RequestClient(url=HUOBI_URL, **credentials)


@one_more_try('查询账户余额', max=3)
def get_all_spot_balance(type_):
    request_client = get_request_client()
    huobi_logger.info(
        f'request_client.get_account_balance_by_account_type({AccountType.SPOT})'
    )
    account: Account = request_client.get_account_balance_by_account_type(
        AccountType.SPOT)
    for balance in account.balances:
        huobi_logger.info(
            f'currency: {balance.currency}, balance: {balance.balance}')
        if type_ == balance.currency and balance.balance > 0.000000001:
示例#20
0
import requests
from typing import List

from common.utils import Logger
from common.utils import one_more_try
from configs import MARKET_KLINE_URL

# Module logger for kline (candlestick) queries.
log = Logger('query kline')


class Kline:
    """One candlestick (k-line) record built from a kline payload dict.

    The payload's ``id`` field is exposed as ``time``; every other field keeps
    its name. A missing field raises ``KeyError`` for the first absent key, in
    the order listed below.
    """

    def __init__(self, data: dict):
        # (attribute name, payload key) pairs, read in this exact order.
        field_map = (
            ('time', 'id'),
            ('amount', 'amount'),
            ('count', 'count'),
            ('open', 'open'),
            ('close', 'close'),
            ('low', 'low'),
            ('high', 'high'),
            ('vol', 'vol'),
        )
        for attr_name, payload_key in field_map:
            setattr(self, attr_name, data[payload_key])


@one_more_try(message='request huobi kline data;请求火币数据')
def get_kline_data(symbol: str, period: str, size: int) -> List[Kline]:
    """
    发送查询,获得数据字典组成的列表。
    :param symbol:such as btcusdt, thetausdt
    :param period: 1min, 5min, 15min, 30min, 60min, 4hour, 1day, 1week, 1mon, 1year
    :param size: [1, 200]
    :return:list, sorted by time, the newest at first
    [{