示例#1
0
    def setup_method(self):
        """
        Prepare the fixtures shared by the test methods.

        Initializes the logger (and removes its file logger), creates the
        DMDS parameters and loads one train/validation id split per scene
        from the ``driving_stereo`` depth collection.
        """
        Logger.init()
        Logger.remove_file_logger()

        self.params = DmdsParams()

        # Database access is configured through the local ini file.
        Config.add_config('./config.ini')
        db_target = ("local_mongodb", "depth", "driving_stereo")
        scene_tokens = (
            "2018-10-26-15-24-18",
            "2018-10-19-09-30-39",
        )

        def ids_for_scene(token):
            # Unshuffled ids of a single scene, sorted by timestamp.
            return load_ids(db_target,
                            data_split=(80, 20),
                            limit=100,
                            shuffle_data=False,
                            mongodb_filter={"scene_token": token},
                            sort_by={"timestamp": 1})

        self.train_data = []
        self.val_data = []
        self.collection_details = []

        # One entry per scene in each of the three parallel lists.
        for token in scene_tokens:
            train_ids, val_ids = ids_for_scene(token)
            self.train_data.append(train_ids)
            self.val_data.append(val_ids)
            self.collection_details.append(db_target)
示例#2
0
    def _init_train_ops(self):
        """
        Copy the training hyper-parameters from ``self.args`` and build the
        training helpers: output directory, reward tracker, logger and
        optimizer.
        """
        args = self.args

        # Hyper-parameters taken directly from the arguments.
        self.n_step = args.n_step
        self.log_iters = args.log_iters
        self.learning_rate = args.lr
        self.gamma = args.gamma
        self.tau = args.tau
        self.vf_coef = args.vf_coef
        self.ent_coef = args.ent_coef

        # Derived quantities: one batch is n_step transitions per environment.
        self.threads = self.env.num_envs
        self.n_batch = self.n_step * self.threads
        self.train_iters = (args.timesteps // self.n_batch) + 1
        self.max_grad_norm = 0.5

        self.outdir = args.outdir
        if not os.path.exists(self.outdir):
            os.mkdir(self.outdir)

        # Fall back to a single life when the arguments do not define any.
        lives = getattr(args, 'game_lives', 1)
        self.rew_tracker = RewardTracker(self.threads, lives)
        self.logger = Logger(args.outdir)
        self.optimizer = optim.Adam(self.policy.parameters(),
                                    lr=self.learning_rate)

        # TODO: may not be the best way to do this
        self.logger.set_headers([
            "timestep", 'mean_rew', "best_mean_rew", "episodes", "policy_loss",
            "value_loss", "entropy", "time_elapsed"
        ])
示例#3
0
  def _connect(self, http_method, rel_path, params, return_raw_response = False):
    """Call the API endpoint ``self._base_url + rel_path``.

    Args:
      http_method: one of HttpMethod.GET / PUT / POST / DELETE.
      rel_path: path appended to the base URL.
      params: payload handed to the connect function as ``data``.
      return_raw_response: when True, return the raw response string
        instead of the parsed API response.

    Returns:
      None for an unsupported method. Otherwise the raw response string
      (None if the call failed) when ``return_raw_response`` is set, or
      the result of ``self._parseRawResponse``.
    """
    url  = self._base_url + rel_path
    data = params

    if http_method == HttpMethod.GET:
      connect = self._http_connect_lib.get
    elif http_method == HttpMethod.PUT:
      connect = self._http_connect_lib.put
    elif http_method == HttpMethod.POST:
      connect = self._http_connect_lib.post
    elif http_method == HttpMethod.DELETE:
      connect = self._http_connect_lib.delete
    else:
      return None # not supported

    # Pre-bind every result so a failed call cannot leave a name undefined
    # (the raw string used to be unbound when connect() raised).
    success, raw_response_str, raw_response_json = False, None, None
    try:
      # TODO: verify certificate for Https connections
      raw_response_str = connect(url = url, data = data, verify = False)
      raw_response_json = json.loads(raw_response_str)
      success = True
    except Exception as e:
      # Format the exception itself: ``e.message`` does not exist on
      # Python 3 and would raise inside this handler.
      Logger.printError('Failed to call API: %s, data: %s, error: %s'%(url, data, e))
      traceback.print_exc()
      sys.stdout.flush()

    if return_raw_response:
      return raw_response_str
    return self._parseRawResponse(success, raw_response_json)
示例#4
0
    def _connect(self, method, url, data, verify):
        """Run the request through the ``curl`` command line and return its stdout.

        GET requests append ``self._assembleQueryString(data)`` to the URL when
        ``data`` is truthy. DELETE, POST and PUT send ``data`` JSON-encoded via
        ``-d`` with a JSON content-type header.

        ``verify`` is accepted but not used anywhere in this body.

        Returns whatever curl wrote to stdout. stderr is captured and discarded
        and the exit status is not checked.

        Raises:
            Exception: for any other HTTP method.
        """
        if method == HttpMethod.GET:
            if not data:
                full_url = url
            else:
                full_url = url + self._assembleQueryString(data)
            curl_cmd = 'curl "%s"' % (full_url)
        elif (method == HttpMethod.DELETE):
            # Need to add "X-Auth-User" header for DELETE request
            # NOTE(review): no such header is added by the command below.
            json_str = json.dumps(data)
            # The payload is wrapped in single quotes on the command line, so
            # single quotes inside it are replaced.
            # NOTE(review): on Python 3 '\u0027' is itself a single quote, which
            # makes this replace a no-op; only a Python 2 byte string yields the
            # literal JSON escape. Confirm the target interpreter.
            rectified_json_str = json_str.replace("'", '\u0027')
            curl_cmd = """curl -X %s -H "Content-Type: application/json" %s -d '%s'""" % (
                method, url, rectified_json_str)
        elif (method == HttpMethod.POST) or (method == HttpMethod.PUT):
            # Need to add "X-Auth-Token" header for POST and PUT request
            # NOTE(review): no such header is added by the command below.
            json_str = json.dumps(data)
            rectified_json_str = json_str.replace("'", '\u0027')
            try:
                curl_cmd = """curl -X %s -H "Content-Type: application/json" %s -d '%s'""" % (
                    method, url, rectified_json_str)
            except:
                # Fallback when %-formatting fails: build the command with
                # str.format and drop undecodable characters.
                # NOTE(review): ``unicode`` exists only on Python 2.
                curl_cmd_str = """curl -X {0} -H "Content-Type: application/json" {1} -d '{2}'""".format(
                    method, url, rectified_json_str)
                curl_cmd = unicode(curl_cmd_str, errors='ignore')
        else:
            raise Exception('Http method not supported! method: %s' % (method))

        Logger.printDebug('Curl command: %s' % (curl_cmd))

        # NOTE(review): the command string is interpreted by the shell
        # (shell=True); url and data must be trusted by the caller.
        proc = subprocess.Popen(curl_cmd,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                shell=True)
        out, err = proc.communicate()
        return out
示例#5
0
 def _parseRawResponse(self, success, raw_response_json):
   """Convert a decoded server response into an HttpSuccess / HttpError.

   Returns an HttpError with SERVER_CONNECTION_ERROR when ``success`` is
   false; otherwise an HttpSuccess wrapping ``raw_response_json['result']``.
   """
   if not success:
     Logger.printError('Invalid server response! raw_response_json: %s'%(raw_response_json))
     return HttpError(ApiErrorCode.SERVER_CONNECTION_ERROR,
       'Failed to connect to server!')
   # The response body lives under the 'result' key.
   return HttpSuccess(raw_response_json['result'])
示例#6
0
 def load(path_to_key_info_list_json):
     """Load key infos from a JSON file into ``KeyInfoManager._key_info_map``.

     The file must contain a JSON list; every element is parsed with
     ``KeyInfo.fromJson`` and stored under its alias. Entries already in
     the map are kept (the map is not cleared first).

     Returns:
         True when every entry was parsed, False as soon as one fails.
         Entries parsed before the failure remain stored.
     """
     with open(path_to_key_info_list_json) as key_info_list_json_file:
         key_info_list_json = json.load(key_info_list_json_file)
     for key_info_json in key_info_list_json:
         success, key_info = KeyInfo.fromJson(key_info_json)
         if not success:
             Logger.printError('Failed to parse key info: %s' %
                               (key_info_json))
             return False
         KeyInfoManager._key_info_map[key_info.alias] = key_info
     return True
  def run(self):
    """Run pending scheduled jobs once per second.

    The loop ends when the job invocation count reaches the configured
    maximum (if one is set) or on KeyboardInterrupt / SystemExit.
    """
    def limit_reached():
      # No limit configured means the loop never stops by itself.
      cap = self._max_job_invocation_count
      return (cap is not None) and (self._job_invocation_count >= cap)

    try:
      done = False
      while not done:
        schedule.run_pending()
        time.sleep(1)
        done = limit_reached()
      Logger.printInfo('Max job invocation count reached. Exiting...')
    except (KeyboardInterrupt, SystemExit):
      pass
    def load(path_to_config_json):
        """Read the JSON config file and load it into ``ConfigManager.config``.

        Returns:
            True on success, False when ``config.load`` raises. Errors while
            opening or decoding the file itself propagate to the caller.
        """
        with open(path_to_config_json) as config_json_file:
            config_json = json.load(config_json_file)

        config = ConfigManager.config
        try:
            config.load(config_json)
        except Exception:
            # format_exc() returns the traceback text; print_exc() returns
            # None, which used to be interpolated into the message.
            Logger.printError('Failed to load config file! file path: %s %s' %
                              (path_to_config_json, traceback.format_exc()))
            return False

        return True
示例#9
0
    def __init__(self, brokers, currency):
        """Store the brokers and currency and compute the initial profit spread."""
        self.brokers = brokers
        self.currency = currency
        self.log = Logger()

        # Layout note for the spread tables:
        #   first key  = bidder exchange
        #   second key = asker exchange
        #   spread['CRYPTSY']['COINSE'] = -0.1 # raw profits (bid - ask)

        # maintains hi_bids and lo_asks for each broker
        self.prices = {}
        # base and alt balances for each exchange
        self.balances = {}
        # price spreads with transaction fees applied
        self.profit_spread = {}
        # actual ALT profits, accounting for balances and volumes
        self.profits = {}

        # self.update_balances()
        # automatically perform calculations upon initialization
        self.update_profit_spread()
示例#10
0
    def setup_method(self):
        """
        Prepare CenterNet parameters and a train/validation id split for
        the test methods.
        """
        Logger.init()
        Logger.remove_file_logger()

        # Enable the two optional regression fields on top of the defaults.
        self.params = CenternetParams(len(OD_CLASS_MAPPING))
        for field_name in ("l_shape", "3d_info"):
            self.params.REGRESSION_FIELDS[field_name].active = True

        # Database access is configured through the local ini file.
        Config.add_config('./config.ini')
        self.collection_details = ("local_mongodb", "labels", "nuscenes_train")

        # Ids that the data generators are built from.
        id_split = load_ids(self.collection_details,
                            data_split=(70, 30),
                            limit=250)
        self.train_data, self.val_data = id_split
示例#11
0
    def setup_method(self):
        """
        Prepare semantic-segmentation parameters and a train/validation id
        split for the test methods.
        """
        Logger.init()
        Logger.remove_file_logger()

        self.params = SemsegParams()

        # Database access is configured through the local ini file.
        Config.add_config('./config.ini')
        self.collection_details = ("local_mongodb", "labels", "comma10k")

        # Ids that the data generators are built from.
        id_split = load_ids(self.collection_details,
                            data_split=(70, 30),
                            limit=30)
        self.train_data, self.val_data = id_split
示例#12
0
    def setup_method(self):
        """
        Prepare multitask parameters and a shuffled train/validation id
        split for the test methods.
        """
        Logger.init()
        Logger.remove_file_logger()

        class_count = len(OD_CLASS_MAPPING.items())
        self.params = MultitaskParams(class_count)

        # Database access is configured through the local ini file.
        Config.add_config('./config.ini')
        self.collection_details = ("local_mongodb", "labels", "nuscenes_train")

        # Ids that the data generators are built from.
        id_split = load_ids(self.collection_details,
                            data_split=(70, 30),
                            shuffle_data=True,
                            limit=30)
        self.td, self.vd = id_split
示例#13
0
 def run(self):
     """Process entry point: dispatch queued messages forever."""
     self._log = Logger(PRJ_PATH)
     while True:
         target, kind, msg = self._queue.get()
         # Pick the handler for this target; other targets are ignored.
         if target == 'NOTICE':
             dispatch = self._notice
         elif target == 'LOG':
             dispatch = self._logger
         else:
             continue
         dispatch(kind, msg)
示例#14
0
文件: Bot.py 项目: enplus/parttime
 def __init__(self, config, brokers):
     """
     Set up the bot state.

     config = configuration file
     brokers = array of broker objects
     """
     super(Bot, self).__init__()

     # Collaborators handed in by the caller.
     self.config = config
     self.brokers = brokers

     # Runtime state.
     self.error = False
     self.log = Logger()
     self.backtest_data = None
     self.max_ticks = 0
     self.data_path = abspath(config.TICK_DIR)
     self.trading_enabled = True
     self.tick_i = 0

     # Debug switches copied from the configuration.
     self.debug = config.DEBUG
     self.debug_cls = config.DEBUG_CLEARSCREEN
     self.debug_mktdata = config.DEBUG_MARKETDATA
     self.debug_spread = config.DEBUG_SPREAD
示例#15
0
    def setup_method(self):
        """
        Prepare CenterTracker parameters and a train/validation id split
        for the test methods.
        """
        Logger.init()
        Logger.remove_file_logger()

        class_count = len(OD_CLASS_MAPPING)
        self.params = CentertrackerParams(class_count)

        # Database access is configured through the local ini file.
        Config.add_config('./config.ini')
        self.collection_details = ("local_mongodb", "labels", "kitti")

        # Ids that the data generators are built from.
        id_split = load_ids(self.collection_details,
                            data_split=(70, 30),
                            limit=100)
        self.train_data, self.val_data = id_split
  def _executeJobs(self):
    """Run every job once (when jobs are enabled) and advance the second counter.

    Jobs are disabled while the list is being processed. An exception
    raised by one job is logged and does not stop the remaining jobs.
    """
    if self._jobs_enabled:
      self._disableJobs()
      for scheduled_job in self._job_list:
        try:
          scheduled_job.execute(self._second_counter)
        except Exception as job_error:
          Logger.printError('Caught job exception: ' + str(job_error))
          traceback.print_exc()

        # Failed invocations are counted as well.
        self._job_invocation_count += 1

      self._enableJobs()

    # Advance the counter and wrap it around at the configured maximum.
    self._second_counter += self._wakeup_interval_in_seconds
    wrapped = self._second_counter >= JobScheduler.MAX_SECOND_COUNT
    if wrapped:
      self._second_counter = 0
示例#17
0
def initialize(path_to_config_file, path_to_key_info_json):
    """Initialize configuration, logging, the transaction signer and key infos.

    Args:
        path_to_config_file: path handed to ``ConfigManager.load``.
        path_to_key_info_json: path handed to ``KeyInfoManager.load``.

    Returns:
        True when every step succeeded, False at the first failing step.
    """
    # Only errors from the werkzeug logger are let through.
    log = logging.getLogger('werkzeug')
    log.setLevel(logging.ERROR)
    success = ConfigManager.load(path_to_config_file)
    if not success:
        return False

    config = ConfigManager.getConfig()
    Logger.setLogFolder(config.server_log_file_folder)
    success = TransactionSigner.initialize()
    if not success:
        return False

    success = KeyInfoManager.load(path_to_key_info_json)
    if not success:
        return False

    # Both controller keys must be present; compare with ``is None`` so a
    # custom ``__eq__`` on the key info cannot mask a missing key.
    required_aliases = (KeyAlias.WHITELIST_CONTROLLER,
                        KeyAlias.EXCHANGE_RATE_CONTROLLER)
    for alias in required_aliases:
        if KeyInfoManager.getKeyInfo(alias) is None:
            Logger.printError('%s key info not loaded!' % (alias))
            return False

    # NOTE(review): the signer is initialized a second time here; kept
    # because it cannot be told from this function whether the repeat
    # after loading the keys is required.
    success = TransactionSigner.initialize()
    if not success:
        return False

    return True
示例#18
0
    def _parseRawResponse(self, success, raw_response_json):
        """Translate a decoded server response into HttpSuccess / HttpError.

        Returns an HttpError when the call failed, when the response lacks
        a status or a body, or when the status is neither OK nor SUCCESS;
        otherwise an HttpSuccess wrapping the body.
        """
        if not success:
            Logger.printError(
                'Invalid server response! raw_response_json: %s' %
                (raw_response_json))
            return HttpError(ApiErrorCode.SERVER_CONNECTION_ERROR,
                             'Failed to connect to server!')

        # A response without a status is likely an HTTP error like 404, 500,
        # etc., e.g.: {u'message': u'HTTP 404 Not Found', u'code': 404}
        mandatory_fields = (
            (ApiKey.STATUS, 'Response does not have a status code!'),
            (ApiKey.BODY, 'Response does not have a body!'),
        )
        for field, complaint in mandatory_fields:
            if raw_response_json.get(field, None) is None:
                Logger.printError(
                    'Invalid server response! raw_response_json: %s' %
                    (raw_response_json))
                return HttpError(ApiErrorCode.INVALID_RESPONSE_FORMAT,
                                 complaint)

        status = raw_response_json[ApiKey.STATUS]
        body = raw_response_json[ApiKey.BODY]
        error_code = raw_response_json.get(ApiKey.ERROR_CODE, 0)
        message = raw_response_json.get(ApiKey.MESSAGE, "")

        if status not in [ApiStatus.OK, ApiStatus.SUCCESS]:
            return HttpError(error_code, message)
        return HttpSuccess(body)
    def get(self):
        """Sign an ADD_ACCOUNTS_TO_WHITELIST transaction for the given addresses.

        Query parameters: one or more addresses, plus nonce, gas_price and
        start_gas (each converted with ``int``).

        Returns:
            The JSON form of an HttpSuccess carrying the signed transaction,
            or of an HttpError when a parameter is missing or signing fails.
        """
        addresses = request.args.getlist(ApiKey.ADDRESS)
        nonce = request.args.get(ApiKey.NONCE, None)
        gas_price = request.args.get(ApiKey.GAS_PRICE, None)
        start_gas = request.args.get(ApiKey.START_GAS, None)
        # getlist() yields an empty list (never None) for an absent key, so
        # the address list needs an emptiness check of its own.
        if (not addresses) or (None in [nonce, gas_price, start_gas]):
            return HttpError(
                error_code=ApiErrorCode.QUERY_PARAM_MISSING,
                message=
                'Query parameter missing! requires address, nonce, gas_price, and start_gas!'
            ).toJson()

        nonce = int(nonce)
        gas_price = int(gas_price)
        start_gas = int(start_gas)
        function_name = SupportedFunctionName.ADD_ACCOUNTS_TO_WHITELIST
        function_params = addresses
        # The transaction is sent from, and signed by, the whitelist controller key.
        key_info = KeyInfoManager.getKeyInfo(KeyAlias.WHITELIST_CONTROLLER)

        success, signed_tx = TransactionSigner.signTransaction(
            from_addr=key_info.address,
            nonce=nonce,
            gas_price=gas_price,
            start_gas=start_gas,
            smart_contract_name=SmartContractName.THETA_TOKEN_SALE,
            function_name=function_name,
            function_params=function_params,
            private_key=key_info.private_key)

        if not success:
            return HttpError(
                error_code=ApiErrorCode.TX_SIGNING_FAILURE,
                message='Failed to sign the transaction!').toJson()

        Logger.printInfo('Signed transaction - nonce: %s, function_name: %s, function_params: %s, signed_tx: %s'%\
          (nonce, function_name, function_params, signed_tx))

        return HttpSuccess({ApiKey.SIGNED_TX: signed_tx}).toJson()
示例#20
0
    def setup_method(self):
        """
        Prepare parameters and a single shuffled train/validation id split,
        each wrapped in a one-element list, for the test methods.
        """
        Logger.init()
        Logger.remove_file_logger()

        self.params = Params()

        # Database access is configured through the local ini file.
        Config.add_config('./config.ini')
        db_target = ("local_mongodb", "labels", "nuscenes_train")

        train_ids, val_ids = load_ids(db_target,
                                      data_split=(70, 30),
                                      limit=100,
                                      shuffle_data=True)

        # The attributes are lists with one entry per collection.
        self.train_data = [train_ids]
        self.val_data = [val_ids]
        self.collection_details = [db_target]
示例#21
0
    def load(path_to_config_json):
        """Read the JSON config file into ``ConfigManager._config``.

        Returns:
            True when every expected key was copied (and the manager is
            marked initialized), False when copying a value raises, e.g.
            for a missing key. Errors while opening or decoding the file
            itself propagate to the caller.
        """
        with open(path_to_config_json) as config_json_file:
            config_json = json.load(config_json_file)

        config = ConfigManager._config
        try:
            config.port = config_json[ConfigKey.PORT]
            config.server_log_file_folder = config_json[
                ConfigKey.SERVER_LOG_FILE_FOLDER]
            config.theta_contract_address = config_json[
                ConfigKey.THETA_CONTRACT_ADDRESS]
            config.theta_abi_file_path = config_json[
                ConfigKey.THETA_ABI_FILE_PATH]
            config.theta_token_sale_contract_address = config_json[
                ConfigKey.THETA_TOKEN_SALE_CONTRACT_ADDRESS]
            config.theta_token_sale_abi_file_path = config_json[
                ConfigKey.THETA_TOKEN_SALE_ABI_FILE_PATH]
        except Exception:
            # format_exc() returns the traceback text; print_exc() returns
            # None, which used to be interpolated into the message.
            Logger.printError('Failed to load config file! file path: %s %s' %
                              (path_to_config_json, traceback.format_exc()))
            return False

        ConfigManager._initialized = True
        return True
 def Export(self, start_height, end_height):
     """Export events for heights ``start_height`` through ``end_height``.

     Work resumes after the already processed height and proceeds in
     batches of at most ``EthereumEventExtractor.HEIGHT_STEP`` heights.
     """
     Logger.printInfo("Start to extract events from height %s to %s..." %
                      (start_height, end_height))
     processed_height = self.getProcessedHeight()
     if processed_height < start_height:
         # Nothing usable was processed yet: begin right at start_height.
         processed_height = start_height - 1
     else:
         Logger.printInfo("Already extracted up to height %s" %
                          (processed_height))
         if processed_height < end_height:
             Logger.printInfo("Continue from height %s..." %
                              (processed_height + 1))

     batch_start = processed_height + 1
     while batch_start <= end_height:
         batch_end = min(
             batch_start + EthereumEventExtractor.HEIGHT_STEP - 1,
             end_height)
         self.export(batch_start, batch_end)
         Logger.printInfo("Extracted events from height %s to %s" %
                          (batch_start, batch_end))
         batch_start = batch_end + 1
def sanityChecks(analyzed_balance_map, queried_balance_map,
                 expected_total_supply):
    """Cross-check analyzed balances against queried ones and the total supply.

    Returns False at the first address whose queried balance is missing or
    differs from the analyzed one, or when the sum of the queried balances
    differs from ``expected_total_supply``; True otherwise.
    """
    queried_sum = 0
    for holder, analyzed in analyzed_balance_map.items():
        # -1 marks an address that has no queried balance.
        queried = queried_balance_map.get(holder, -1)
        if (queried == -1) or (analyzed != queried):
            Logger.printError(
                "Balance mismatch for address: %s, analyzed_balance = %s, queried_balance = %s"
                % (holder, analyzed, queried))
            return False
        queried_sum += int(queried)

    Logger.printInfo('Expected total supply  : %s' % (expected_total_supply))
    Logger.printInfo('Sum of queried balances: %s' % (queried_sum))
    if queried_sum == expected_total_supply:
        return True

    Logger.printError(
        'Token total supply mismatch. expected = %s, calculated = %s' %
        (expected_total_supply, queried_sum))
    return False
示例#24
0
def main():
    """Build the testbed topology and render the WAN and every site once.

    Does nothing when the module-level ``net`` is already set. An optional
    first command-line argument names the topology configuration file.
    """
    global net
    global renderers
    global rendererIndex
    global vpns
    global vpnIndex
    if net:
        return

    renderers = []
    rendererIndex = {}
    vpns = []  # used in vpn.py only
    vpnIndex = {}  # used in vpn.py only

    random.seed(0)  # in order to get the same result to simplify debugging
    InitLogger()
    if len(sys.argv) > 1:
        net = TestbedTopology(fileName=sys.argv[1])
    else:
        net = TestbedTopology()
    # One-time setup for the VPN service
    wi = WanIntent("esnet", net.builder.wan)
    wr = WanRenderer(wi)
    wr.execute()
    renderers.append(wr)
    rendererIndex[wr.name] = wr

    for site in net.builder.sites:
        intent = SiteIntent(name=site.name, site=site)
        sr = SiteRenderer(intent)
        suc = sr.execute()  # no function without vpn information
        if not suc:
            Logger().warning('%r.execute() fail', sr)
        # A renderer is registered even when its execution failed.
        renderers.append(sr)
        rendererIndex[sr.name] = sr
    # Parenthesized single-argument form prints the same text on Python 2
    # and is valid syntax on Python 3.
    print("Now the demo environment is ready.")
示例#25
0
 def Query(self, addresses, target_height):
     """Query the token balance of every address at ``target_height``.

     Returns a dict mapping each address to its balance converted with
     ``str``. Progress is logged every 1000 addresses and once at the end.
     """
     Logger.printInfo("Total number of addresses: %s" % (len(addresses)))
     balances_by_address = {}
     queried_count = 0  # stays 0 when there are no addresses
     for queried_count, address in enumerate(addresses, 1):
         balance = self.ethereum_rpc_service.GetTokenBalance(
             smart_contract_address=self.smart_contract_address,
             account_address=address,
             target_height=hex(target_height))
         balances_by_address[address] = str(balance)
         if queried_count % 1000 == 0:
             Logger.printInfo("%s addresses queried." %
                              (queried_count))
     Logger.printInfo("%s addresses queried." % (queried_count))
     return balances_by_address
示例#26
0
class Result:
    """Outcome of one task check: pass flag, message, rule and task identity."""

    logger = Logger('result')

    def __init__(self, msg: str, pass_: bool, rule: str, index: int, task_id: int,
                 task_name: str, mission_name: str):
        # Identity of the task that produced this result.
        self.index = index
        self.task_id = task_id
        self.task_name = task_name
        self.mission_name = mission_name

        # The outcome itself.
        self.pass_ = pass_
        self.msg = msg
        self.rule = rule

    def __bool__(self):
        # A result is truthy exactly when the check passed.
        return self.pass_

    def __str__(self):
        # Rule, message and a separator line, one per line.
        separator = '<' + '=' * 50 + '>'
        return '\n'.join([self.rule, self.msg, separator])

    @classmethod
    def create_by_task_engine(cls, task_engine, pass_):
        """Build a result from a task engine, prefixing the task message with PASS or FAIL."""
        cls.logger.info('try to create by task engine')
        task = task_engine.task
        prefix = 'PASS ' if pass_ else 'FAIL '
        msg = prefix + task.msg
        return cls(msg, pass_, task.RULE, task_engine.task_index, task_engine.task_id,
                   task_engine.task_name, task_engine.mission_name)

    @property
    def short_str(self):
        """The message without rule and separator."""
        return self.msg
示例#27
0
from huobi.requstclient import RequestClient
from huobi.model.constant import OrderType, AccountType
from huobi.model import OrderState, Account

from common.utils import one_more_try, Logger
from configs import AccessKey, SecretKey, HUOBI_URL, TRADE_MATCH, TESTING, NOTTRADE

huobi_logger = Logger('huobi')


def _mask_credential(value, visible=4):
    """Return the first ``visible`` characters of ``value`` followed by ``****``."""
    return str(value)[:visible] + '****'


# Credentials must not be written to the log in clear text: the access key
# is truncated and the secret key is hidden completely.
huobi_logger.info(f'AccessKey: {_mask_credential(AccessKey)}')
huobi_logger.info(f'SecretKey: {_mask_credential(SecretKey, visible=0)}')


def get_request_client():
    """Create a RequestClient from the configured API keys and Huobi URL."""
    client = RequestClient(
        api_key=AccessKey,
        secret_key=SecretKey,
        url=HUOBI_URL,
    )
    return client


@one_more_try('查询账户余额', max=3)
def get_all_spot_balance(type_):
    request_client = get_request_client()
    huobi_logger.info(
        f'request_client.get_account_balance_by_account_type({AccountType.SPOT})'
    )
    account: Account = request_client.get_account_balance_by_account_type(
        AccountType.SPOT)
    for balance in account.balances:
        huobi_logger.info(
            f'currency: {balance.currency}, balance: {balance.balance}')
        if type_ == balance.currency and balance.balance > 0.000000001:
def train(func_train_one_batch, param_dict, path, log_dir_path):
    """Train the embedding model together with its generators and discriminator.

    Args:
        func_train_one_batch: callable that runs one optimisation step on a
            triplet batch and returns ``(loss_pos_gen, loss_neg_gen, loss_dis)``.
        param_dict: keyword arguments forwarded to ``Logger``; the resulting
            object ``p`` is read as the hyper-parameter namespace.
        path: dataset root handed to the dataset constructor.
        log_dir_path: directory for the logger and the loss plots.

    After every epoch the model is evaluated; all four models are saved
    whenever the NMI improves. The per-epoch average losses are plotted to
    PNG files in ``log_dir_path`` once training has finished.
    """
    dis_loss = []
    pos_gen_loss = []
    neg_gen_loss = []

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    p = Logger(log_dir_path, **param_dict)

    # load data base
    # Compare by value: ``is`` tests object identity and only matches a
    # string literal by accident of interning.
    if p.dataset == 'car196':
        data = Car196(root=path)
    else:
        print('DATASET is', p.dataset)
        data = CUB_200_2011(root=path)

    sampler = BalancedBatchSampler(data.train.label_to_indices,
                                   n_samples=p.n_samples,
                                   n_classes=p.n_classes)
    kwargs = {'num_workers': 6, 'pin_memory': True}

    train_loader = DataLoader(data.train, batch_sampler=sampler,
                              **kwargs)  # (5 * 98, 3, 224, 224)

    # train_iter = iter(train_loader)
    # batch = next(train_iter)
    # generate_random_triplets_from_batch(batch, p.n_samples, p.n_classes)

    test_loader = DataLoader(data.test, batch_size=p.batch_size)

    # construct the model
    model = ModifiedGoogLeNet(p.out_dim, p.normalize_hidden).to(device)
    model_pos_gen = Generator(p.out_dim, p.normalize_hidden).to(device)
    model_neg_gen = Generator(p.out_dim, p.normalize_output).to(device)

    model_dis = Discriminator(p.out_dim, p.out_dim).to(device)

    model_optimizer = optim.Adam(model.parameters(), lr=p.learning_rate)
    pos_gen_optimizer = optim.Adam(model_pos_gen.parameters(),
                                   lr=p.learning_rate)
    neg_gen_optimizer = optim.Adam(model_neg_gen.parameters(),
                                   lr=p.learning_rate)
    dis_optimizer = optim.Adam(model_dis.parameters(), lr=p.learning_rate)
    # Second, independent optimizer over the same embedding-model parameters.
    model_feat_optimizer = optim.Adam(model.parameters(), lr=p.learning_rate)

    time_origin = time.time()
    # Scores at the epoch with the best NMI (…_1) and the best F1 (…_2).
    best_nmi_1 = 0.
    best_f1_1 = 0.
    best_nmi_2 = 0.
    best_f1_2 = 0.

    for epoch in range(p.num_epochs):
        time_begin = time.time()
        epoch_loss_neg_gen = 0
        epoch_loss_pos_gen = 0
        epoch_loss_dis = 0
        total = 0
        for batch in tqdm(train_loader, desc='# {}'.format(epoch)):
            triplet_batch = generate_random_triplets_from_batch(
                batch, n_samples=p.n_samples, n_class=p.n_classes)
            loss_pos_gen, loss_neg_gen, loss_dis = func_train_one_batch(
                device, model, model_pos_gen, model_neg_gen, model_dis,
                model_optimizer, model_feat_optimizer, pos_gen_optimizer,
                neg_gen_optimizer, dis_optimizer, p, triplet_batch, epoch)

            epoch_loss_neg_gen += loss_neg_gen
            epoch_loss_pos_gen += loss_pos_gen
            epoch_loss_dis += loss_dis
            total += triplet_batch[0].size(0)

        # Average over the number of triplets seen in this epoch.
        loss_average_neg_gen = epoch_loss_neg_gen / total
        loss_average_pos_gen = epoch_loss_pos_gen / total
        loss_average_dis = epoch_loss_dis / total

        dis_loss.append(loss_average_dis)
        pos_gen_loss.append(loss_average_pos_gen)
        neg_gen_loss.append(loss_average_neg_gen)

        nmi, f1 = evaluate(device,
                           model,
                           model_dis,
                           test_loader,
                           epoch,
                           n_classes=p.n_classes,
                           distance=p.distance_type,
                           normalize=p.normalize_output,
                           neg_gen_epoch=p.neg_gen_epoch)
        if nmi > best_nmi_1:
            best_nmi_1 = nmi
            best_f1_1 = f1
            # Checkpoint all four models at the best NMI seen so far.
            torch.save(model, os.path.join(p.model_save_path, "model.pt"))
            torch.save(model_pos_gen,
                       os.path.join(p.model_save_path, "model_pos_gen.pt"))
            torch.save(model_neg_gen,
                       os.path.join(p.model_save_path, "model_neg_gen.pt"))
            torch.save(model_dis,
                       os.path.join(p.model_save_path, "model_dis.pt"))
        if f1 > best_f1_2:
            best_nmi_2 = nmi
            best_f1_2 = f1

        time_end = time.time()
        epoch_time = time_end - time_begin
        total_time = time_end - time_origin

        print("#", epoch)
        print("time: {} ({})".format(epoch_time, total_time))
        print("[train] loss NEG gen:", loss_average_neg_gen)
        print("[train] loss POS gen:", loss_average_pos_gen)
        print("[train] loss dis:", loss_average_dis)
        print("[test]  nmi:", nmi)
        print("[test]  f1:", f1)
        print("[test]  nmi:", best_nmi_1, "  f1:", best_f1_1, "for max nmi")
        print("[test]  nmi:", best_nmi_2, "  f1:", best_f1_2, "for max f1")
        print(p)

    # One plot per tracked loss curve, written into the log directory.
    loss_curves = (
        ("dis_loss", dis_loss),
        ("pos_gen_loss", pos_gen_loss),
        ("neg_gen_loss", neg_gen_loss),
    )
    for curve_name, curve in loss_curves:
        plt.plot(curve)
        plt.ylabel(curve_name)
        plt.savefig(log_dir_path + '/' + curve_name + '.png')
        plt.close()
#      "expected_total_supply" : 1000000000000000000000000000
#    }
#

if __name__ == '__main__':
    # Exactly three arguments are expected after the script name.
    if len(sys.argv) != 4:
        print(
            "\nUsage: python run.py <config_file_path> <target_height> <balance_file_path>\n"
        )
        # sys.exit() instead of exit(): the latter is an interactive helper
        # injected by the site module and is not guaranteed to exist.
        sys.exit(1)
    #Logger.enableDebugLog()

    config_file_path = sys.argv[1]
    target_height = int(sys.argv[2])
    balance_file_path = sys.argv[3]

    cfgMgr = ConfigManager()
    if not cfgMgr.load(config_file_path):
        Logger.printError('Failed to load config: %s' % (config_file_path))
        sys.exit(1)

    # Settings read from the loaded configuration.
    config = cfgMgr.config
    ethereum_rpc_url = config.ethereum_rpc_url
    smart_contract_address = config.smart_contract_address
    expected_total_supply = config.expected_total_supply
    genesis_height = config.genesis_height

    exportTokenBalance(ethereum_rpc_url, smart_contract_address,
                       expected_total_supply, genesis_height, target_height,
                       balance_file_path)
def exportTokenBalance(ethereum_rpc_url, smart_contract_address,
                       expected_total_supply, genesis_height, target_height,
                       balance_file_path):
    """Export the token balance of every holder at ``target_height``.

    Steps: export the contract events from ``genesis_height`` to
    ``target_height`` into ``./data/events``, analyze them into a balance
    map, query the balance of every analyzed address, run the sanity checks
    and write the queried balances as JSON to ``balance_file_path``.

    Terminates the process with status 1 when the sanity checks fail.
    """
    # Imported locally so the function does not rely on a module-level import.
    import sys

    export_folder = "./data/events"
    if not os.path.exists(export_folder):
        os.makedirs(export_folder)

    Logger.printInfo('')
    Logger.printInfo('Start exporting Ethereum events...')
    eee = EthereumEventExtractor(ethereum_rpc_url, smart_contract_address,
                                 export_folder)
    eee.Export(genesis_height, target_height)
    Logger.printInfo('Ethereum events exported.')
    Logger.printInfo('')

    Logger.printInfo('Start extracting token holders...')
    eea = EthereumEventAnalyzer()
    analyzed_balance_map = eea.Analyze(export_folder, target_height)
    Logger.printInfo('Token holders extracted.')
    Logger.printInfo('')

    #with open(balance_file_path + '.analyzed', 'w') as balance_file:
    #  json.dump(analyzed_balance_map, balance_file, indent=2)

    Logger.printInfo(
        'Start querying the balance of each holder at block height %s, may take a while...'
        % (target_height))
    token_holder_addresses = analyzed_balance_map.keys()
    tbe = TokenBalanceExtractor(ethereum_rpc_url, smart_contract_address)
    queried_balance_map = tbe.Query(token_holder_addresses, target_height)
    Logger.printInfo('Token holders balance retrieved.')
    Logger.printInfo('')

    #with open(balance_file_path + '.queried', 'w') as balance_file:
    #  json.dump(queried_balance_map, balance_file, indent=2)

    Logger.printInfo('Start sanity checks...')
    if not sanityChecks(analyzed_balance_map, queried_balance_map,
                        expected_total_supply):
        Logger.printError('Sanity checks failed.')
        # sys.exit() instead of exit(): the latter is an interactive helper
        # injected by the site module and is not guaranteed to exist.
        sys.exit(1)
    Logger.printInfo('Sanity checks all passed.')
    Logger.printInfo('')

    # Only balances that passed the checks are written out.
    with open(balance_file_path, 'w') as balance_file:
        json.dump(queried_balance_map, balance_file, indent=2)

    Logger.printInfo('Token balances calculated and exported to: %s' %
                     (balance_file_path))
    Logger.printInfo('')