示例#1
0
 def _preprocess_sample(self, sample_str):
     """
     preprocess each sample with the limitation of maximum length and pad each sample to maximum length
     :param sample_str: Str format of json data, "Dict{'token': List[Str], 'label': List[Str]}"
     :return: sample -> Dict{'token': List[int], 'label': List[int], 'token_len': int}
     """
     raw_sample = json.loads(sample_str)
     sample = {'token': [], 'label': []}
     for k in raw_sample.keys():
         if k == 'token':
             sample[k] = [
                 self.vocab.v2i[k].get(v.lower(), self.vocab.oov_index)
                 for v in raw_sample[k]
             ]
         else:
             sample[k] = []
             for v in raw_sample[k]:
                 if v not in self.vocab.v2i[k].keys():
                     logger.warning('Vocab not in ' + k + ' ' + v)
                 else:
                     sample[k].append(self.vocab.v2i[k][v])
     if not sample['token']:
         sample['token'].append(self.vocab.padding_index)
     if self.mode == 'TRAIN':
         assert sample['label'], 'Label is empty'
     else:
         sample['label'] = [0]
     sample['token_len'] = min(len(sample['token']), self.max_input_length)
     padding = [
         self.vocab.padding_index
         for _ in range(0, self.max_input_length - len(sample['token']))
     ]
     sample['token'] += padding
     sample['token'] = sample['token'][:self.max_input_length]
     return sample
示例#2
0
 def check_signature_algorithm(self):
     logger.info("Der verwendete Signaturalgorithmus ist : " +
                 str(self.cert.signature_algorithm_oid._name))
     logger.info("Die zugehörige OID lautet: " +
                 str(self.cert.signature_algorithm_oid.dotted_string))
     # ok wenn signaturhash in liste enthalten
     if self.cert.signature_algorithm_oid._name in self.sig_hashes_ok:
         logger.info("Das ist OK")
     else:
         logger.warning("Bitte mit Hilfe der Checkliste überprüfen")
示例#3
0
 def update_lr(self):
     """
     (callback function) update learning rate according to the decay weight
     """
     logger.warning('Learning rate update {}--->{}'.format(
         self.optimizer.param_groups[0]['lr'],
         self.optimizer.param_groups[0]['lr'] *
         self.config.train.optimizer.lr_decay))
     for param in self.optimizer.param_groups:
         param[
             'lr'] = self.config.train.optimizer.learning_rate * self.config.train.optimizer.lr_decay
    def check_basic_constraint(self):
        #Anforderung 2.2.5
        try:
            basic_constraint_extension = self.cert.extensions.get_extension_for_class(
                x509.BasicConstraints)
            logger.info("Das Zertifikat hat eine BasicContraint Extension")
            logger.warning("Der Inhalt der BasicContraint Extension ist: " +
                           str(basic_constraint_extension))

            #TODO: Die Extension könnte man noch nett auswerten.

        except Exception as err:
            logger.error("Das Zertifikat hat keine BasicContraint Extension")
示例#5
0
    def test_key_exchange(self):
        #Anforderung 2.4.1
        openssl_cmd_getcert = "echo | openssl s_client -msg -connect " + self.hostname + ":" + str(
            self.port
        ) + self.openssl_client_proxy_part + " | grep 'ServerKey' -A 5"

        proc = subprocess.Popen([openssl_cmd_getcert],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                shell=True)
        (out, err) = proc.communicate()
        logger.warning(
            "Die Nachricht muss leider noch ausgewertet werden. Das ist das Einzige, was noch nicht funktioniert"
        )
        logger.warning(out)
    def check_cert_for_extended_keyusage(self):
        try:
            keyusage_extension = self.cert.extensions.get_extension_for_class(
                x509.ExtendedKeyUsage)
            # logger.info("Das Zertifikat hat eine ExtendedKeyUsage Extension mit den folgenden Eigenschaften")
            # logger.warning("serverAuth: "+ str(keyusage_extension.value.SERVER_AUTH))

            for usg in keyusage_extension.value._usages:
                logger.warning(
                    "Das Zertifikat hat eine ExtendedKeyUsage Extension mit den folgenden Eigenschaften"
                    + usg._name)

            #TODO: Ist das der richtige Wert?
        except Exception as err:
            print err
示例#7
0
    def check_cert_for_keyusage_ca(self):
        try:
            keyusage_extension = self.cert.extensions.get_extension_for_class(
                x509.KeyUsage)
            if keyusage_extension.critical and keyusage_extension.value.key_cert_sign and keyusage_extension.value.crl_sign:
                logger.info(
                    "Das Zertifikat hat die korrekten KeyUsage-Bits, das ist so OK."
                )
                logger.info("critical: " + str(keyusage_extension.critical))
                logger.info("key_cert_sign: " +
                            str(keyusage_extension.value.key_cert_sign))
                logger.info("crl_sign: " +
                            str(keyusage_extension.value.crl_sign))
            else:
                logger.error(
                    "Das Zertifikat hat abweichende KeyUsage-Bits, das ist nicht OK"
                )
                logger.warning("critical: " + str(keyusage_extension.critical))
                logger.warning("key_cert_sign: " +
                               str(keyusage_extension.value.key_cert_sign))
                logger.warning("crl_sign: " +
                               str(keyusage_extension.value.crl_sign))

        except Exception as err:
            if "No <class 'cryptography.x509.extensions.KeyUsage'> extension was found" in str(
                    err):
                logger.error("Es wurde keine keyUsage Extension gefunden")
            else:
                print err
    def check_cert_for_revocation(self):

        tmp_file = tempfile.NamedTemporaryFile(delete=False)
        tmp_file.write(self.cert.public_bytes(serialization.Encoding.PEM))
        tmp_file.close()

        try:
            crl_extension = self.cert.extensions.get_extension_for_class(
                x509.CRLDistributionPoints)
            logger.info(
                "Das Zertifikat hat eine CRLDistributionPoint Extension")

            openssl_cmd_getcert = "openssl verify -crl_check_all -CAfile " + self.ca_file + " " + tmp_file.name

            proc = subprocess.Popen([openssl_cmd_getcert],
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE,
                                    shell=True)
            (out, err) = proc.communicate()

            logger.warning(
                "Die Prüfung des Zertifikats gegen die CRL hat folgendes Ergebnis ergeben:"
            )
            logger.warning(out)
            logger.warning(err)

        except Exception as err:
            logger.error(
                "Fehler bei der Prüfung des Revocation status. Existiert keine CRLDistributionPoint Extension? Es folgt die Ausgabe des Fehlers"
            )
            logger.error(err)
示例#9
0
    def check_basic_constraint(self):
        #Anforderung 2.2.5
        try:
            basic_constraint_extension = self.cert.extensions.get_extension_for_class(
                x509.BasicConstraints)
            # test auf vorhandensein der critical constraint
            if basic_constraint_extension.critical:
                logger.info(
                    "Das Zertifikat hat eine als kritisch markierte BasicContraint Extension. Das ist so OK"
                )
                logger.info("Der Inhalt der BasicContraint Extension ist: " +
                            str(basic_constraint_extension))
            else:
                logger.error(
                    "Das Zertifikat hat eine nicht kritisch markierte BasicContraint Extension. Das ist nicht OK"
                )
                logger.warning(
                    "Der Inhalt der BasicContraint Extension ist: " +
                    str(basic_constraint_extension))

        except Exception as err:
            logger.error("Das Zertifikat hat keine BasicContraint Extension")
示例#10
0
    def check_cert_for_extended_keyusage(self):
        try:
            # liste der extended-keyusage extension auslesen
            keyusage_extension = self.cert.extensions.get_extension_for_class(
                x509.ExtendedKeyUsage)
            usg_list = []
            for usg in keyusage_extension.value._usages:
                usg_list.append(usg._name)

            # test auf serverAuth
            if "serverAuth" in usg_list:
                contains_serverauth = True
                logger.info(
                    "Das Zertifikat hat eine ExtendedKeyUsage Extension mit dem Eintrag serverAuth"
                )
            else:
                contains_serverauth = False
                logger.warning(
                    "Das Zertifikat hat eine ExtendedKeyUsage Extension mit den folgenden Eigenschaften:",
                    usg_list)

        except Exception as err:
            logger.error("Das Zertifikat hat keine ExtendedKeyUsage Extension")
            print err
示例#11
0
    def check_cert_for_revocation(self):

        # cn auslesen
        for entry in self.cert.subject._attributes:
            for attr in entry:
                if attr.oid._name == "commonName":
                    cn = attr.value
                    break

        # cn in tmpfilenamen einbetten
        tmp_file = tempfile.NamedTemporaryFile(prefix="tmp.checklist.%s" %
                                               (cn),
                                               delete=False)
        tmp_file.write(self.cert.public_bytes(serialization.Encoding.PEM))
        tmp_file.close()

        try:
            crl_extension = self.cert.extensions.get_extension_for_class(
                x509.CRLDistributionPoints)
            logger.info(
                "Das Zertifikat hat eine CRLDistributionPoint Extension")

            if self.ca_file:
                openssl_ca_opt = "-CAfile " + self.ca_file
            else:
                openssl_ca_opt = ""

            # Download der crl und Prüfung auf Zertifikatablauf
            openssl_cmd_getcert = "openssl verify -crl_check_all -crl_download " + openssl_ca_opt + " " + tmp_file.name

            proc = subprocess.Popen([openssl_cmd_getcert],
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE,
                                    shell=True)
            (out, err) = proc.communicate()

            logger.warning(
                "Die Prüfung des Zertifikats gegen die CRL hat folgendes Ergebnis ergeben:"
            )
            logger.warning(out)
            logger.warning(err)

        except Exception as err:
            logger.error(
                "Fehler bei der Prüfung des Revocation status. Existiert keine CRLDistributionPoint Extension? Es folgt die Ausgabe des Fehlers"
            )
            logger.error(err)
示例#12
0
    def check_cert_for_keyusage(self):
        try:
            keyusage_extension = self.cert.extensions.get_extension_for_class(
                x509.KeyUsage)
            logger.info(
                "Das Zertifikat hat eine KeyUsage Extension mit den folgenden Eigenschaften"
            )
            logger.warning("digital_signature: " +
                           str(keyusage_extension.value.digital_signature))
            logger.warning("key_cert_sign: " +
                           str(keyusage_extension.value.key_cert_sign))
            logger.warning("crl_sign: " +
                           str(keyusage_extension.value.crl_sign))

            #TODO: Man könnte die Werte auch gleich prüfen, allerdings ist das für CA Zertifkate anders und daher etwas komplizierter.

        except Exception as err:
            if "No <class 'cryptography.x509.extensions.KeyUsage'> extension was found" in str(
                    err):
                logger.error("Es wurde keine keyUsage Extension gefunden")
            else:
                print err
示例#13
0
    def test_supported_cipher_suites(self):
        #Anforderung 2.3.2/2.3.3/2.3.4
        #TODO: Funktioniert aktuell nur mit RSA
        crypto_type = "RSA"
        openssl_cmd_getcert = "openssl ciphers"
        proc = subprocess.Popen([openssl_cmd_getcert],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                shell=True)
        (out, err) = proc.communicate()

        out = out.replace('\n', '').replace('\r', '')
        all_ciphers = out.split(":")
        all_ciphers = filter(None, all_ciphers)
        all_ciphers = filter(None, all_ciphers)

        for cipher in all_ciphers:
            try:
                cipher_list = [
                    x for x in self.cipher_suites
                    if x[1] == cipher and x[2] == crypto_type
                ]
                allowed = should = must = optional = False

                if len(cipher_list) == 0:
                    allowed = False
                elif cipher_list[0][3] == "MUST":
                    must = True
                    allowed = True
                elif cipher_list[0][3] == "SHOULD":
                    should = True
                    allowed = True
                elif cipher_list[0][3] == "OPTIONAL":
                    optional = True
                    allowed = True

                context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
                context.set_ciphers(cipher)
                context.verify_mode = ssl.CERT_REQUIRED
                context.check_hostname = True
                context.load_default_certs()

                ssl_sock = self.__connect_ssl_socket(context)
                priority = ssl_sock.cipher()[2]

                if not allowed:
                    logger.error(
                        "Server unterstützt verbotene cipher-suite: " +
                        cipher + " mit Priorität" + str(priority) +
                        " Das sollte nicht der Fall sein")

                elif must or should or optional:
                    logger.warning(cipher + " wird unterstützt mit Priorität" +
                                   str(priority) +
                                   ". Bitte in der Checkliste prüfen.")

            except ssl.SSLError as err:
                if len(err.args) > 1 and (
                        "SSLV3_ALERT_HANDSHAKE_FAILURE" in err.args[1]
                        or "NO_CIPHERS_AVAILABLE" in err.args[1]):
                    if must:
                        logger.error(
                            cipher +
                            " wird nicht unterstützt aber von der Checkliste gefordert"
                        )
                    else:
                        logger.info(
                            cipher +
                            " wird nicht unterstützt. Das scheint OK zu sein.")
                if len(err.args) == 1:
                    if must:
                        logger.error(
                            cipher +
                            " wird nicht unterstützt aber von der Checkliste gefordert"
                        )
                    else:
                        logger.info(
                            cipher +
                            " wird nicht unterstützt. Das scheint OK zu sein.")
示例#14
0
    def test_session_renegotiation(self):
        #Anforderung 2.5.1

        if self.ca_file:
            sslyze_ca_opt = "--ca_file=" + self.ca_file
        else:
            sslyze_ca_opt = ""
        if self.clientcert_file:
            sslyze_clientcert_opt = "--cert=" + self.clientcert_file + " --key=" + self.clientcert_file
        else:
            sslyze_clientcert_opt = ""

        openssl_cmd_getcert = "sslyze --reneg " + sslyze_ca_opt + " " + sslyze_clientcert_opt + " " + self.hostname + ":" + str(
            self.port) + self.sslyze_proxy_part
        proc = subprocess.Popen([openssl_cmd_getcert],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                shell=True)
        (out, err) = proc.communicate()

        if "ClientCertificateRequested" in out:
            logger.warning("sslyze Fehler: ClientCertificateRequested")
            return

        if "Unhandled exception" in out:
            logger.warning("unbekannter Fehler bei Ausführung von sslyze")
            print(out)
            return

# Anmerkung: die sslyze-ausgabe client-initiated renegotiation: ok-rejected ist ein rein positives security-merkmal und dürfte nicht als regel-verstoß gewertet werden.
# bitte prüfen
        if "Client-initiated Renegotiation" in out:
            if "Client-initiated Renegotiation:    OK - Rejected" in out:
                logger.info(
                    "Server unterstützt client-initiierte session renegotiation nicht. Das ist so OK"
                )
                logger.warning(
                    " - in der upstream-version dieses prüftools wird dieser prüfpunkt als fehler gewertet. ggfs. diesen punkt klären."
                )
            else:
                logger.error(
                    "Server unterstützt client-initiierte session renegotiation. Das sollte nicht der Fall sein."
                )
        else:
            logger.warning("kein Ergebnis für Client-initiated Renegotiation")

# Anmerkung: die secure renegotiation ist eine rein positives security-merkmal und dürfte nicht als regel-verstoß gewertet werden.
# bitte prüfen
        if "Secure Renegotiation:" in out:
            if "Secure Renegotiation:              OK - Supported" in out:
                logger.info(
                    "Der Server unterstützt die sichere Form der renegotiaion. Das ist so OK."
                )
                logger.warning(
                    " - in der upstream-version dieses prüftools wird dieser prüfpunkt als fehler gewertet. ggfs. diesen punkt klären."
                )
            else:
                logger.warning(
                    "Der Server unterstützt die sichere Form der renegotiation nicht. Bitte im Detail prüfen."
                )
        else:
            logger.warning("kein Ergebnis für Secure Renegotiation")
示例#15
0
 def check_signature_algorithm(self):
     logger.warning("Der verwendete Signaturalgorithmus ist : " +
                    str(self.cert.signature_algorithm_oid._name))
     logger.warning("Die zugehörige OID lautet: " +
                    str(self.cert.signature_algorithm_oid.dotted_string))
     logger.warning("Bitte mit Hilfe der Checkliste überprüfen")
示例#16
0
    def test_supported_cipher_suites(self):
        #Anforderung 2.3.2/2.3.3/2.3.4
        #TODO: Funktioniert aktuell nur mit RSA
        crypto_type = "RSA"
        openssl_cmd_getcert = "openssl ciphers"
        proc = subprocess.Popen([openssl_cmd_getcert],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                shell=True)
        (out, err) = proc.communicate()

        out = out.replace('\n', '').replace('\r', '')
        all_ciphers = out.split(":")
        all_ciphers = filter(None, all_ciphers)
        all_ciphers = filter(None, all_ciphers)

        for cipher in all_ciphers:
            cipher_list = [
                x for x in self.cipher_suites
                if x[1] == cipher and x[2] == crypto_type
            ]
            allowed = should = must = optional = False

            if len(cipher_list) == 0:
                allowed = False
            elif cipher_list[0][3] == "MUST":
                must = True
                allowed = True
            elif cipher_list[0][3] == "SHOULD":
                should = True
                allowed = True
            elif cipher_list[0][3] == "OPTIONAL":
                optional = True
                allowed = True

            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
            context.set_ciphers(cipher)
            if self.insecure:
                context.verify_mode = ssl.CERT_NONE
                context.check_hostname = False
            else:
                context.verify_mode = ssl.CERT_REQUIRED
                context.check_hostname = True

            if self.ca_file:
                context.load_verify_locations(cafile=self.ca_file)
            else:
                context.load_default_certs()
            if self.clientcert_file:
                context.load_cert_chain(certfile=self.clientcert_file)

            try:
                s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                ssl_sock = context.wrap_socket(s,
                                               server_hostname=self.hostname)
                ssl_sock.connect((self.hostname, self.port))
                priority = ssl_sock.cipher()[2]

                if not allowed:
                    logger.error(
                        "Server unterstützt verbotene cipher-suite: " +
                        cipher + " mit Priorität" + str(priority) +
                        " Das sollte nicht der Fall sein")

                elif must or should or optional:
                    logger.warning(cipher + " wird unterstützt mit Priorität" +
                                   str(priority) +
                                   ". Bitte in der Checkliste prüfen.")

            # Zertifikatfehler
            except ssl.CertificateError as err:
                logger.error("Zertifikatfehler bei Überprüfung von %s" %
                             (cipher))
                print(err)

            # ssl Verbindungsabbruch
            except ssl.SSLError as err:
                if len(err.args) > 1:
                    if "SSLV3_ALERT_HANDSHAKE_FAILURE" in err.args[
                            1] or "NO_CIPHERS_AVAILABLE" in err.args[
                                1] or "EOF occurred in violation of protocol" in err.args[
                                    1]:
                        if must:
                            logger.error(
                                cipher +
                                " wird nicht unterstützt aber von der Checkliste gefordert"
                            )
                        else:
                            logger.info(
                                cipher +
                                " wird nicht unterstützt. Das scheint OK zu sein."
                            )
                    # DH Key zu klein
                    elif "dh key too small" in err.args[1]:
                        logger.warn(cipher + " " + err.args[1])
                    # sonstiger Grund
                    else:
                        logger.warn(cipher +
                                    " verursacht einen Verbindungsfehler")
                        print(err.args[1])
                if len(err.args) == 1:
                    if must:
                        logger.error(
                            cipher +
                            " wird nicht unterstützt aber von der Checkliste gefordert"
                        )
                    else:
                        logger.info(
                            cipher +
                            " wird nicht unterstützt. Das scheint OK zu sein.")

            # socket Fehler
            except socket.error as err:
                if must:
                    logger.error(
                        cipher +
                        " wird nicht unterstützt aber von der Checkliste gefordert"
                    )
                else:
                    logger.info(
                        cipher +
                        " wird nicht unterstützt. Das scheint OK zu sein.")
示例#17
0
def train(config):
    """
    :param config: helper.configure, Configure Object
    """
    # loading corpus and generate vocabulary
    corpus_vocab = Vocab(config, min_freq=5, max_size=50000)

    # get data
    train_loader, dev_loader, test_loader = data_loaders(config, corpus_vocab)

    # build up model
    htcinfomax = HTCInfoMax(config, corpus_vocab, model_mode='TRAIN')
    htcinfomax.to(config.train.device_setting.device)
    # define training objective & optimizer
    criterion = ClassificationLoss(
        os.path.join(config.data.data_dir, config.data.hierarchy),
        corpus_vocab.v2i['label'],
        recursive_penalty=config.train.loss.recursive_regularization.penalty,
        recursive_constraint=config.train.loss.recursive_regularization.flag)
    optimize = set_optimizer(config, htcinfomax)

    # get epoch trainer
    trainer = Trainer(model=htcinfomax,
                      criterion=criterion,
                      optimizer=optimize,
                      vocab=corpus_vocab,
                      config=config)

    # set origin log
    best_epoch = [-1, -1]
    best_performance = [0.0, 0.0]
    model_checkpoint = config.train.checkpoint.dir
    model_name = config.model.type
    wait = 0
    if not os.path.isdir(model_checkpoint):
        os.mkdir(model_checkpoint)
    else:
        # loading previous checkpoint
        dir_list = os.listdir(model_checkpoint)
        dir_list.sort(key=lambda fn: os.path.getatime(
            os.path.join(model_checkpoint, fn)))
        latest_model_file = ''
        for model_file in dir_list[::-1]:
            if model_file.startswith('best'):
                continue
            else:
                latest_model_file = model_file
                break
        if os.path.isfile(os.path.join(model_checkpoint, latest_model_file)):
            logger.info('Loading Previous Checkpoint...')
            logger.info('Loading from {}'.format(
                os.path.join(model_checkpoint, latest_model_file)))
            best_performance, config = load_checkpoint(model_file=os.path.join(
                model_checkpoint, latest_model_file),
                                                       model=htcinfomax,
                                                       config=config,
                                                       optimizer=optimize)
            logger.info(
                'Previous Best Performance---- Micro-F1: {}%, Macro-F1: {}%'.
                format(best_performance[0], best_performance[1]))

    # train
    for epoch in range(config.train.start_epoch, config.train.end_epoch):
        start_time = time.time()
        trainer.train(train_loader, epoch)
        trainer.eval(train_loader, epoch, 'TRAIN')
        performance = trainer.eval(dev_loader, epoch, 'DEV')
        # saving best model and check model
        if not (performance['micro_f1'] >= best_performance[0]
                or performance['macro_f1'] >= best_performance[1]):
            wait += 1
            if wait % config.train.optimizer.lr_patience == 0:
                logger.warning(
                    "Performance has not been improved for {} epochs, updating learning rate"
                    .format(wait))
                trainer.update_lr()
            if wait == config.train.optimizer.early_stopping:
                logger.warning(
                    "Performance has not been improved for {} epochs, stopping train with early stopping"
                    .format(wait))
                break

        if performance['micro_f1'] > best_performance[0]:
            wait = 0
            logger.info('Improve Micro-F1 {}% --> {}%'.format(
                best_performance[0], performance['micro_f1']))
            best_performance[0] = performance['micro_f1']
            best_epoch[0] = epoch
            save_checkpoint(
                {
                    'epoch': epoch,
                    'model_type': config.model.type,
                    'state_dict': htcinfomax.state_dict(),
                    'best_performance': best_performance,
                    'optimizer': optimize.state_dict()
                }, os.path.join(model_checkpoint, 'best_micro_' + model_name))
        if performance['macro_f1'] > best_performance[1]:
            wait = 0
            logger.info('Improve Macro-F1 {}% --> {}%'.format(
                best_performance[1], performance['macro_f1']))
            best_performance[1] = performance['macro_f1']
            best_epoch[1] = epoch
            save_checkpoint(
                {
                    'epoch': epoch,
                    'model_type': config.model.type,
                    'state_dict': htcinfomax.state_dict(),
                    'best_performance': best_performance,
                    'optimizer': optimize.state_dict()
                }, os.path.join(model_checkpoint, 'best_macro_' + model_name))

        if epoch % 10 == 1:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'model_type': config.model.type,
                    'state_dict': htcinfomax.state_dict(),
                    'best_performance': best_performance,
                    'optimizer': optimize.state_dict()
                },
                os.path.join(model_checkpoint,
                             model_name + '_epoch_' + str(epoch)))

        logger.info('Epoch {} Time Cost {} secs.'.format(
            epoch,
            time.time() - start_time))

    best_epoch_model_file = os.path.join(model_checkpoint,
                                         'best_micro_' + model_name)
    if os.path.isfile(best_epoch_model_file):
        load_checkpoint(best_epoch_model_file,
                        model=htcinfomax,
                        config=config,
                        optimizer=optimize)
        trainer.eval(test_loader, best_epoch[0], 'TEST')

    best_epoch_model_file = os.path.join(model_checkpoint,
                                         'best_macro_' + model_name)
    if os.path.isfile(best_epoch_model_file):
        load_checkpoint(best_epoch_model_file,
                        model=htcinfomax,
                        config=config,
                        optimizer=optimize)
        trainer.eval(test_loader, best_epoch[1], 'TEST')

    return