Code Example #1
def test_verify_cshd(caplog):
    test_verify_root = os.path.join(TESTS_DIR, "test_verify_files", "tt")
    # caplog.set_level sets on root logger by default which is somehow not the logger setup by
    # checksum_helper so specify our logger in the kw param
    caplog.set_level(logging.INFO, logger='Checksum_Helper')
    # ------------ 1 wrong crc, 1 missing ----------
    hfile_path = os.path.join(test_verify_root, "new_cshd_1+3+missing.cshd")
    a = Args(hash_file_name=[hfile_path])
    starting_cwd = os.getcwd()

    caplog.clear()
    # files_total, nr_matches, nr_missing, nr_crc_errors
    assert _cl_verify_hfile(a) == (3, 1, 1, 1)
    # cwd hasn't changed
    assert starting_cwd == os.getcwd()
    assert x_contains_all_y(caplog.record_tuples, [
        ('Checksum_Helper', logging.INFO, f'new_cshd.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING,
         f'sub1{os.sep}new_cshd_3.txt: SHA512 FAILED'),
        ('Checksum_Helper', logging.WARNING,
         f'sub2{os.sep}new_cshd_missing.txt: MISSING'),
    ])

    assert caplog.record_tuples[3:] == [
        ('Checksum_Helper', logging.WARNING,
         f'{test_verify_root}{os.sep}new_cshd_1+3+missing.cshd: 1 files with wrong CRCs!'
         ),
        ('Checksum_Helper', logging.WARNING,
         f'{test_verify_root}{os.sep}new_cshd_1+3+missing.cshd: 1 missing files!'
         ),
    ]
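The helper x_contains_all_y used throughout these tests is not part of the excerpts. A minimal sketch of what it presumably does (every expected log record occurs somewhere in the captured records, regardless of order) could look like this:

def x_contains_all_y(x, y):
    # hypothetical sketch, not the project's actual helper:
    # True if every expected record in y appears somewhere in x
    return all(item in x for item in y)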
Code Example #2
def test_build_most_current_cshd(hash_fn_filter, search_depth,
                                 dont_filter_deleted, verified_cshd_name,
                                 setup_dir_to_checksum, monkeypatch):
    root_dir = setup_dir_to_checksum

    shutil.copy2(
        os.path.join(TESTS_DIR, "test_build_most_current_files",
                     "pre-existing.cshd"), root_dir)

    a = Args(path=root_dir,
             hash_filename_filter=hash_fn_filter,
             discover_hash_files_depth=search_depth,
             dont_filter_deleted=dont_filter_deleted,
             hash_algorithm="sha512",
             out_filename="most_current.cshd")
    _cl_build_most_current(a)

    verified_cshd_contents = read_file(
        os.path.join(TESTS_DIR, "test_build_most_current_files",
                     verified_cshd_name))
    generated_cshd_name = f"{root_dir}{os.sep}most_current.cshd"
    generated_cshd_contents = read_file(generated_cshd_name)

    print("VERIFIED:", verified_cshd_contents.strip())
    print("GEN:", generated_cshd_contents)
    assert (verified_cshd_contents == generated_cshd_contents)
Code Example #3
def test_copyto(setup_tmpdir_param, monkeypatch, caplog) -> None:
    tmpdir = setup_tmpdir_param
    root_dir = os.path.join(tmpdir, "tt")
    shutil.copytree(os.path.join(TESTS_DIR, "test_copyto_files", "tt"),
                    os.path.join(root_dir))

    with open(os.path.join(TESTS_DIR, "test_copyto_files", "tt", "tt.sha512"),
              'r',
              encoding='utf-8-sig') as f:
        orig = f.read().replace("\\", os.sep)

    # set input to automatically answer with y so we write the file when asked
    monkeypatch.setattr('builtins.input', lambda x: "y")
    a = Args(source_path=os.path.join(root_dir, "tt.sha512"),
             dest_path=f".{os.sep}sub2{os.sep}tt_moved.sha512")
    hf = _cl_copy(a)
    assert os.path.isfile(os.path.join(root_dir, "sub2", "tt_moved.sha512"))

    with open(os.path.join(root_dir, "sub2", "tt_moved.sha512"),
              'r',
              encoding='utf-8-sig') as f:
        moved = f.read()
    expected = orig.replace("*n", f"*../n").replace(f"*sub2/", "*").replace(
        "*sub1", f"*../sub1")
    assert expected == moved

    a = Args(source_path=os.path.join(root_dir, "sub2", "tt_moved.sha512"),
             dest_path=f"..{os.sep}sub1{os.sep}sub2{os.sep}tt_moved2.sha512")
    hf = _cl_copy(a)
    assert os.path.isfile(
        os.path.join(root_dir, "sub1", "sub2", "tt_moved2.sha512"))

    with open(os.path.join(root_dir, "sub1", "sub2", "tt_moved2.sha512"),
              'r',
              encoding='utf-8-sig') as f:
        moved = f.read()
    expected = orig.replace("*n", f"*../../n").replace(
        f"*sub2/", f"*../../sub2/").replace(f"*sub1/sub2/",
                                            "*").replace(f"*sub1/", f"*../")
    assert expected == moved

    a = Args(source_path=os.path.join(root_dir, "sub1", "sub2",
                                      "tt_moved2.sha512"),
             dest_path=f"..{os.sep}.{os.sep}tt_moved3.sha512")
    hf = _cl_copy(a)
    # not reading the written file back in, only making sure it was written to the correct location
    assert os.path.isfile(os.path.join(root_dir, "sub1", "tt_moved3.sha512"))
Code Example #4
def test_do_incremental_per_dir(whitelist, blacklist, expected_dir,
                                setup_tmpdir_param):
    tmpdir = setup_tmpdir_param
    root_dir = os.path.join(tmpdir, "tt")
    shutil.copytree(
        os.path.join(TESTS_DIR, "test_incremental_files", "per_dir"),
        os.path.join(root_dir, ""))

    a = Args(path=root_dir,
             hash_filename_filter=None,
             single_hash=True,
             dont_include_unchanged=False,
             discover_hash_files_depth=-1,
             hash_algorithm="sha512",
             per_directory=True,
             whitelist=whitelist,
             blacklist=blacklist,
             skip_unchanged=False,
             dont_collect_mtime=False)
    _cl_incremental(a)

    expected_res = [
        ("root.sha512", f"tt_{time.strftime('%Y-%m-%d')}.sha512"),
        ("sub1.sha512",
         os.path.join("sub1", f"sub1_{time.strftime('%Y-%m-%d')}.sha512")),
        ("sub2.sha512",
         os.path.join("sub2", f"sub2_{time.strftime('%Y-%m-%d')}.sha512")),
        ("sub3.sha512",
         os.path.join("sub3", f"sub3_{time.strftime('%Y-%m-%d')}.sha512")),
        ("sub4.sha512",
         os.path.join("sub4", f"sub4_{time.strftime('%Y-%m-%d')}.sha512")),
    ]

    for expected_fn, result_fn in expected_res:
        try:
            verified_sha_contents = read_file(
                os.path.join(TESTS_DIR, "test_incremental_files", expected_dir,
                             expected_fn))
        except FileNotFoundError:
            # make sure generated file is also missing
            assert not os.path.exists(os.path.join(root_dir, result_fn))
            continue

        generated_sha_contents = read_file(os.path.join(root_dir, result_fn))

        compare_lines_sorted(verified_sha_contents, generated_sha_contents)
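read_file and compare_lines_sorted are helpers defined outside these excerpts. A minimal sketch, assuming they simply load a file and compare hash-file contents line by line while ignoring line order:

def read_file(path, encoding='utf-8-sig'):
    # hypothetical sketch: return the full text of a file
    with open(path, 'r', encoding=encoding) as f:
        return f.read()


def compare_lines_sorted(expected, generated):
    # hypothetical sketch: same lines in any order counts as a match
    assert sorted(expected.splitlines()) == sorted(generated.splitlines())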
Code Example #5
def test_white_black_list(depth, hash_fn_filter, include_unchanged, whitelist,
                          blacklist, verified_sha_name, setup_tmpdir_param,
                          caplog, monkeypatch):
    tmpdir = setup_tmpdir_param
    root_dir = os.path.join(tmpdir, "wl_bl")
    # When using copytree, you need to ensure that src exists and dst does not exist.
    # Even if the top level directory contains nothing, copytree won't work because it
    # expects nothing to be at dst and will create the top level directory itself.
    shutil.copytree(os.path.join(TESTS_DIR, "test_incremental_files", "wl_bl"),
                    os.path.join(root_dir))
    caplog.clear()
    # caplog.set_level sets on root logger by default which is somehow not the logger setup by
    # checksum_helper so specify our logger in the kw param
    caplog.set_level(logging.WARNING, logger='Checksum_Helper')
    monkeypatch.setattr('builtins.input', lambda x: "y")

    a = Args(path=root_dir,
             hash_filename_filter=hash_fn_filter,
             single_hash=True,
             dont_include_unchanged=not include_unchanged,
             discover_hash_files_depth=depth,
             hash_algorithm="sha512",
             per_directory=False,
             whitelist=whitelist,
             blacklist=blacklist,
             skip_unchanged=False,
             dont_collect_mtime=False)
    _cl_incremental(a)
    if whitelist is not None and blacklist is not None:
        assert caplog.record_tuples == [
            ('Checksum_Helper', logging.ERROR,
             'Can only use either a whitelist or blacklist - not both!'),
        ]
    else:
        verified_sha_contents = read_file(
            os.path.join(TESTS_DIR, "test_incremental_files",
                         verified_sha_name))

        # find written sha (current date is appended)
        generated_sha_name = f"wl_bl_{time.strftime('%Y-%m-%d')}.sha512"
        generated_sha_contents = read_file(
            os.path.join(root_dir, generated_sha_name))

        compare_lines_sorted(verified_sha_contents, generated_sha_contents)
Code Example #6
def test_build_most_current_single(hash_fn_filter, search_depth,
                                   dont_filter_deleted, verified_sha_name,
                                   setup_dir_to_checksum, monkeypatch):
    root_dir = setup_dir_to_checksum

    a = Args(path=root_dir,
             hash_filename_filter=hash_fn_filter,
             discover_hash_files_depth=search_depth,
             dont_filter_deleted=dont_filter_deleted,
             hash_algorithm="sha512",
             out_filename="most_current.sha512")
    _cl_build_most_current(a)

    verified_sha_contents = read_file(
        os.path.join(TESTS_DIR, "test_build_most_current_files",
                     verified_sha_name))
    generated_sha_name = f"{root_dir}{os.sep}most_current.sha512"
    generated_sha_contents = read_file(generated_sha_name)

    assert (verified_sha_contents == generated_sha_contents)
Code Example #7

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        dimension = torch.arange(0, d_model)

        pe[:, 0::2] = torch.sin(
            position / torch.pow(10000, 2 *
                                 (dimension // 2)[::2].float() / d_model))
        pe[:, 1::2] = torch.cos(
            position / torch.pow(10000, 2 *
                                 (dimension // 2)[1::2].float() / d_model))
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


if __name__ == "__main__":
    from utils import Args
    from vae import Vae  # Vae lives in this project's vae module (cf. Code Example #23)
    args = Args('train_data_path', 'save_path')
    args.vocab_size = 200
    model = Vae(args)
    model.generate(decode='greedy')
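A short usage sketch for the PositionalEncoding module above, assuming batch-first embeddings of shape (batch, seq_len, d_model):

import torch

pos_enc = PositionalEncoding(d_model=16, dropout=0.0)
x = torch.zeros(2, 10, 16)   # dummy embeddings: (batch, seq_len, d_model)
out = pos_enc(x)             # sinusoidal encodings added, then dropout applied
print(out.shape)             # torch.Size([2, 10, 16])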
Code Example #8
import torch
import torch.nn as nn

import numpy as np
from utils import Args, EarlyStopping_unlearning
from losses.confusion_loss import confusion_loss

from losses.dice_loss import dice_loss
from sklearn.utils import shuffle
import torch.optim as optim
from train_utils_segmentation import train_unlearn, val_unlearn, train_encoder_unlearn, val_encoder_unlearn

import sys
########################################################################################################################
# Create an args class
args = Args()
args.channels_first = True
args.epochs = 300
args.batch_size = 3
args.diff_model_flag = False
args.alpha = 50
args.patience = 25
args.epoch_stage_1 = 100
args.epoch_reached = 1
args.beta = 10

cuda = torch.cuda.is_available()

LOAD_PATH_UNET = None
LOAD_PATH_SEGMENTER = None
LOAD_PATH_DOMAIN = None
Code Example #9
import torch
import torch.nn as nn

import numpy as np
from sklearn.utils import shuffle
from utils import Args, EarlyStopping_unlearning
from losses.confusion_loss import confusion_loss

from losses.dice_loss import dice_loss

import torch.optim as optim
from train_utils_segmentation import train_encoder_domain_unlearn_semi, val_encoder_domain_unlearn_semi, train_unlearn_semi, val_unlearn_semi

import sys
########################################################################################################################
# Create an args class
args = Args()
args.channels_first = True
args.epochs = 300
args.batch_size = 4
args.diff_model_flag = False
args.alpha = 50
args.patience = 100

cuda = torch.cuda.is_available()

LOAD_PATH_UNET = None
LOAD_PATH_SEGMENTER = None
LOAD_PATH_DOMAIN = None

PRETRAIN_UNET = 'pretrain_unet'
PATH_UNET = 'unet_pth'
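The unlearning scripts above call Args() with no arguments and then assign attributes one by one; utils.Args itself is not shown in these excerpts. A minimal stand-in that supports this pattern (and the Args(**config) call in Code Example #23) would be:

class Args:
    # hypothetical stand-in for utils.Args: a plain attribute container
    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)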
Code Example #10
import sys
import numpy as np
import gflags

import models
from utils import Accumulator, Args, make_batch

# PyTorch
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim

FLAGS = gflags.FLAGS
args = Args()

gflags.DEFINE_enum(
    "style", "dynamic",
    ["static", "static2", "dynamic", "dynamic2", "fakedynamic", "fakestatic"],
    "Specify dynamic or static RNN loops.")
gflags.DEFINE_boolean("smart_batching", True,
                      "Bucket batches for similar length.")

# Parse command line flags.
FLAGS(sys.argv)

# Set args.
args.training_data_path = 'trees/dev.txt'
args.eval_data_path = 'trees/dev.txt'
args.embedding_data_path = 'glove.6B.50d.txt'
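Once FLAGS(sys.argv) has parsed the command line, the flags declared above are readable as attributes (standard python-gflags behaviour); for example:

print(FLAGS.style)           # 'dynamic' unless overridden with --style
print(FLAGS.smart_batching)  # True; pass --nosmart_batching to disable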
Code Example #11
def test_verify_all(caplog):
    test_verify_root = os.path.join(TESTS_DIR, "test_verify_files", "tt")
    # caplog.set_level sets on root logger by default which is somehow not the logger setup by
    # checksum_helper so specify our logger in the kw param
    caplog.set_level(logging.INFO, logger='Checksum_Helper')
    # ------------ 2 wrong crc, 2 missing, MixedAlgo ----------
    root_dir = test_verify_root
    a = Args(root_dir=[root_dir],
             discover_hash_files_depth=1,
             hash_filename_filter=())

    caplog.clear()
    # files_total, nr_matches, nr_missing, nr_crc_errors
    assert _cl_verify_all(a) == (16, 10, 3, 3)
    assert x_contains_all_y(caplog.record_tuples, [
        ('Checksum_Helper', logging.INFO, f'new 2.txt: MD5 OK'),
        ('Checksum_Helper', logging.WARNING, f'new 3.txt: SHA512 FAILED'),
        ('Checksum_Helper', logging.INFO, f'new 4.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 2.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 3.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 4.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub1{os.sep}sub2{os.sep}new 2.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub1{os.sep}sub2{os.sep}new 3.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING,
         f'sub1{os.sep}sub2{os.sep}new 4.txt: SHA512 FAILED'),
        ('Checksum_Helper', logging.INFO,
         f'sub2{os.sep}sub1{os.sep}file1.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING,
         f'sub5{os.sep}sub1{os.sep}file1.txt: MISSING'),
        ('Checksum_Helper', logging.WARNING,
         f'sub6{os.sep}file1.txt: MISSING'),
        ('Checksum_Helper', logging.INFO, f'new_cshd.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING,
         f'sub1{os.sep}new_cshd_3.txt: SHA512 FAILED'),
        ('Checksum_Helper', logging.WARNING,
         f'sub2{os.sep}new_cshd_missing.txt: MISSING'),
        ('Checksum_Helper', logging.INFO,
         f'sub3{os.sep}sub2{os.sep}new_cshd2.txt: SHA512 OK'),
    ])

    assert caplog.record_tuples[16:] == [
        ('Checksum_Helper', logging.WARNING,
         f"{root_dir}{os.sep}tt_most_current_{time.strftime('%Y-%m-%d')}.cshd: 3 files with wrong CRCs!"
         ),
        ('Checksum_Helper', logging.WARNING,
         f"{root_dir}{os.sep}tt_most_current_{time.strftime('%Y-%m-%d')}.cshd: 3 missing files!"
         ),
    ]

    # ------------ all matching, 1 missing, most_current single hash file ----------
    root_dir = os.path.join(test_verify_root, "sub1", "sub2")
    a = Args(root_dir=[root_dir],
             discover_hash_files_depth=0,
             hash_filename_filter=("*.cshd", ))

    caplog.clear()
    assert _cl_verify_all(a) == (4, 3, 1, 0)
    assert x_contains_all_y(caplog.record_tuples, [
        ('Checksum_Helper', logging.INFO, f'new 2.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'new 3.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'new 4.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING, f'new 8.txt: MISSING'),
    ])
    assert caplog.record_tuples[4:] == [
        ('Checksum_Helper', logging.INFO,
         f'{root_dir}{os.sep}sub2_most_current_{time.strftime("%Y-%m-%d")}.sha512: All files matching their hashes!'
         ),
        ('Checksum_Helper', logging.WARNING,
         f'{root_dir}{os.sep}sub2_most_current_{time.strftime("%Y-%m-%d")}.sha512: 1 missing files!'
         ),
    ]

    # ------------ 3 wrong crc, 4 missing ----------
    root_dir = test_verify_root
    a = Args(root_dir=[root_dir],
             discover_hash_files_depth=-1,
             hash_filename_filter=())

    caplog.clear()
    assert _cl_verify_all(a) == (20, 13, 4, 3)
    assert x_contains_all_y(caplog.record_tuples, [
        ('Checksum_Helper', logging.WARNING,
         "Found reference beyond the hash file's root dir in file: '%s'. "
         "Consider moving/copying the file using ChecksumHelper move/copy "
         "to the path that is the most common denominator!" %
         os.path.join(root_dir, "sub3", "sub2", "sub3_sub2.sha512")),
        ('Checksum_Helper', logging.INFO, f'new 2.txt: MD5 OK'),
        ('Checksum_Helper', logging.INFO, f'sub3{os.sep}file1.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING,
         f'sub3{os.sep}sub1{os.sep}file1.txt: SHA512 FAILED'),
        ('Checksum_Helper', logging.INFO,
         f'sub3{os.sep}sub2{os.sep}file1.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING, f'new 3.txt: SHA512 FAILED'),
        ('Checksum_Helper', logging.INFO, f'new 4.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 2.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 3.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 4.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub1{os.sep}sub2{os.sep}new 2.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub1{os.sep}sub2{os.sep}new 3.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub1{os.sep}sub2{os.sep}new 4.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub2{os.sep}sub1{os.sep}file1.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING,
         f'sub5{os.sep}sub1{os.sep}file1.txt: MISSING'),
        ('Checksum_Helper', logging.WARNING,
         f'sub6{os.sep}file1.txt: MISSING'),
        ('Checksum_Helper', logging.WARNING,
         f'sub1{os.sep}sub2{os.sep}new 8.txt: MISSING'),
        ('Checksum_Helper', logging.INFO, f'new_cshd.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING,
         f'sub1{os.sep}new_cshd_3.txt: SHA512 FAILED'),
        ('Checksum_Helper', logging.WARNING,
         f'sub2{os.sep}new_cshd_missing.txt: MISSING'),
        ('Checksum_Helper', logging.INFO,
         f'sub3{os.sep}sub2{os.sep}new_cshd2.txt: SHA512 OK'),
    ])
    assert caplog.record_tuples[21:] == [
        ('Checksum_Helper', logging.WARNING,
         f"{root_dir}{os.sep}tt_most_current_{time.strftime('%Y-%m-%d')}.cshd: 3 files with wrong CRCs!"
         ),
        ('Checksum_Helper', logging.WARNING,
         f"{root_dir}{os.sep}tt_most_current_{time.strftime('%Y-%m-%d')}.cshd: 4 missing files!"
         ),
    ]

    # ------------ 2 wrong crc, 3 missing, single hash, md5+cshd filtered  ----------
    root_dir = test_verify_root
    # hash_filename_filter literally only filters out the hashfile if a str of
    # hash_filename_filter is in the name of the file without the extension
    a = Args(root_dir=[root_dir],
             discover_hash_files_depth=-1,
             hash_filename_filter=("*.md5", "*.cshd"))

    caplog.clear()
    assert _cl_verify_all(a) == (15, 10, 3, 2)
    assert x_contains_all_y(caplog.record_tuples, [
        ('Checksum_Helper', logging.WARNING,
         "Found reference beyond the hash file's root dir in file: '%s'. "
         "Consider moving/copying the file using ChecksumHelper move/copy "
         "to the path that is the most common denominator!" %
         os.path.join(root_dir, "sub3", "sub2", "sub3_sub2.sha512")),
        ('Checksum_Helper', logging.INFO, f'sub3{os.sep}file1.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING,
         f'sub3{os.sep}sub1{os.sep}file1.txt: SHA512 FAILED'),
        ('Checksum_Helper', logging.INFO,
         f'sub3{os.sep}sub2{os.sep}file1.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING, f'new 3.txt: SHA512 FAILED'),
        ('Checksum_Helper', logging.INFO, f'new 4.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 2.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 3.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 4.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub1{os.sep}sub2{os.sep}new 2.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub1{os.sep}sub2{os.sep}new 3.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub1{os.sep}sub2{os.sep}new 4.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub2{os.sep}sub1{os.sep}file1.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING,
         f'sub5{os.sep}sub1{os.sep}file1.txt: MISSING'),
        ('Checksum_Helper', logging.WARNING,
         f'sub6{os.sep}file1.txt: MISSING'),
        ('Checksum_Helper', logging.WARNING,
         f'sub1{os.sep}sub2{os.sep}new 8.txt: MISSING'),
    ])

    assert caplog.record_tuples[16:] == [
        ('Checksum_Helper', logging.WARNING,
         f"{root_dir}{os.sep}tt_most_current_{time.strftime('%Y-%m-%d')}.sha512: 2 files with wrong CRCs!"
         ),
        ('Checksum_Helper', logging.WARNING,
         f"{root_dir}{os.sep}tt_most_current_{time.strftime('%Y-%m-%d')}.sha512: 3 missing files!"
         ),
    ]
Code Example #12
def test_move_files(src, dst, depth, hash_fn_filter, expected_files_dirname,
                    moved_paths, extra_cmp, setup_dir_to_checksum, caplog,
                    monkeypatch):
    root_dir = setup_dir_to_checksum

    # absolute-path destination test only applies on Windows
    if os.name != 'nt' and os.path.isabs(dst):
        return

    caplog.set_level(logging.WARNING)
    # clear logging records
    caplog.clear()
    a = Args(root_dir=root_dir,
             hash_filename_filter=hash_fn_filter,
             discover_hash_files_depth=depth,
             source_path=src,
             mv_path=dst)
    _cl_move(a)
    if src == os.path.join("sub4", "file1.txt") and dst == os.path.join(
            "sub2", "sub1"):
        # filter out hash file pardir warning
        assert [
            rt for rt in caplog.record_tuples
            if not rt[2].startswith("Found reference beyond")
        ] == [
            ('Checksum_Helper', logging.ERROR, "File %s already exists!" %
             (os.path.join(root_dir, "sub2", "sub1", "file1.txt"), )),
        ]
    elif src == os.path.join("sub4", "file1.txt") and dst == os.path.join(
            "sub2", "sub1", "file1.txt"):
        assert [
            rt for rt in caplog.record_tuples
            if not rt[2].startswith("Found reference beyond")
        ] == [
            ('Checksum_Helper', logging.ERROR, "File %s already exists!" %
             (os.path.join(root_dir, "sub2", "sub1", "file1.txt"), )),
        ]
    elif src == "sub1" and dst == "sub2":
        assert [
            rt for rt in caplog.record_tuples
            if not rt[2].startswith("Found reference beyond")
        ] == [
            ('Checksum_Helper', logging.ERROR,
             "Couldn't move file(s): Destination path '%s' already exists" %
             (os.path.join(root_dir, "sub2", "sub1"), )),
        ]
    elif src == "sub1" and dst == os.path.join("sub1", "sub5"):
        assert [
            rt for rt in caplog.record_tuples
            if not rt[2].startswith("Found reference beyond")
        ] == [
            ('Checksum_Helper', logging.ERROR,
             "Couldn't move file(s): Cannot move a directory '%s' into itself '%s'."
             % (os.path.join(root_dir,
                             "sub1"), os.path.join(root_dir, "sub1", "sub5"))),
        ]
    elif src == os.path.join("sub3", "sub1") and dst == "sub4":
        assert [
            rt for rt in caplog.record_tuples
            if not rt[2].startswith("Found reference beyond")
        ] == [
            ('Checksum_Helper', logging.ERROR,
             "Couldn't move file(s): Destination path '%s' already exists" %
             (os.path.join(root_dir, "sub4", "sub1"))),
        ]
    elif src == os.path.join("sub4", "file1.txt") and dst == os.path.join(
        [d for d in "abcdefghijklmnopqrstuvwxyz" if d != TESTS_DIR[0].lower()
         ][0] + ":", os.sep, "sub2", "sub1"):
        assert [
            rt for rt in caplog.record_tuples
            if not rt[2].startswith("Found reference beyond")
        ] == [
            ('Checksum_Helper', logging.ERROR,
             "Can't move files to a different drive than the hash files "
             "that hold their hashes!"),
        ]
    else:
        # only check that src doesn't exist when we shouldn't error
        assert not os.path.exists(src)

    # check that files are at the expected locations
    assert all(
        [os.path.exists(os.path.join(root_dir, p)) for p in moved_paths])

    # check that file paths were moved inside of hash files
    for hf_name in [hf for hf in HASH_FILES if not hf.startswith(src)]:
        assert (sort_hf_contents(
            read_file(os.path.join(root_dir, hf_name),
                      encoding="UTF-8-SIG")) == sort_hf_contents(
                          read_file(os.path.join(test_modf_dir_abs,
                                                 expected_files_dirname,
                                                 os.path.basename(hf_name)),
                                    encoding="UTF-8-SIG")))
    # additional compares when a hash file was moved
    for hf_current, hf_expected in extra_cmp:
        assert (sort_hf_contents(
            read_file(os.path.join(root_dir, hf_current),
                      encoding="UTF-8-SIG")) == sort_hf_contents(
                          read_file(os.path.join(test_modf_dir_abs,
                                                 expected_files_dirname,
                                                 hf_expected),
                                    encoding="UTF-8-SIG")))
Code Example #13
def main(env_name):

    # get all parameters
    args = Args(env_name)
    env = args.env
    max_epochs = args.max_epochs
    max_timesteps = args.max_timesteps
    update_timestep = args.update_timestep
    print_interval = args.print_interval

    # initialize the memory buffer
    memory = Memory()

    # create the agent instance
    agent = Agent(input_size=args.input_size,
                  output_size=args.output_size,
                  hidden_size=args.hidden_size,
                  lr=args.lr,
                  beta=args.beta,
                  gamma=args.gamma,
                  update_epoch=args.update_epoch,
                  epsilon=args.epsilon)

    reward_plot = [0]  # average reward per print_interval epochs, kept for plotting
    timestep_count = 0  # step counter, reset once update_timestep is reached
    interval_reward = 0  # reward accumulated over print_interval epochs, reset after reporting
    interval_timestep = 0  # timesteps accumulated over print_interval epochs, reset after reporting

    file_name = 'RL_Proj_2/{}.txt'.format(args.env_name)

    # training loop
    for epoch in range(1, max_epochs + 1):
        state = env.reset()  # reset the env to obtain an initial state
        # let the agent act for up to max_timesteps steps
        for timestep in range(max_timesteps):
            timestep_count += 1

            # sample an action from the old policy and step the environment
            action = agent.old_policy.act(state, memory)
            state, reward, done, _ = env.step(action)
            memory.rewards.append(reward)
            memory.is_done.append(done)

            # check whether the policy needs to be updated
            if timestep_count % update_timestep == 0:
                agent.update(memory)
                memory.clear_memory()
                timestep_count = 0

            interval_reward += reward
            env.render()
            if done:
                break

        interval_timestep += timestep

        # report statistics every print_interval epochs
        if epoch % print_interval == 0:
            interval_timestep = np.divide(interval_timestep, print_interval)
            interval_reward = np.divide(interval_reward, print_interval)

            reward_plot.append(interval_reward)

            # store the statistics
            with open(file_name, 'a') as f:
                f.write(
                    str(epoch) + ' ' + str(interval_timestep) + ' ' +
                    str(interval_reward) + '\n')

            print('Epoch {} \t average timestep: {} \t reward: {}'.format(
                epoch, interval_timestep, interval_reward))

            interval_reward = 0
            interval_timestep = 0

    # save the model once training has finished
    torch.save(agent.policy.state_dict(),
               'RL_Proj_2/{}.pth'.format(args.env_name))

    # plot the reward curve
    plt.plot(reward_plot)
    plt.xlabel('Epoch = tick times {}'.format(print_interval))
    plt.ylabel('Reward')
    plt.savefig('RL_Proj_2/{}.png'.format(args.env_name))
    plt.show()
Code Example #14
from datasets.numpy_dataset import numpy_dataset_three, numpy_dataset
from torch.utils.data import DataLoader
import torch
import torch.nn as nn

import numpy as np
from sklearn.utils import shuffle
from utils import Args, EarlyStopping_unlearning
from losses.confusion_loss import confusion_loss
import torch.optim as optim
from train_utils import train_unlearn_distinct, val_unlearn_distinct, val_encoder_domain_unlearn_distinct, train_encoder_domain_unlearn_distinct
import sys

########################################################################################################################
# Create an args class
args = Args()
args.channels_first = True
args.epochs = 300
args.batch_size = 16
args.diff_model_flag = False
args.alpha = 1
args.patience = 150
args.learning_rate = 1e-4

LOAD_PATH_ENCODER = None
LOAD_PATH_REGRESSOR = None
LOAD_PATH_DOMAIN = None

PRE_TRAIN_ENCODER = 'pretrain_encoder'
PATH_ENCODER = 'encoder_pth'
CHK_PATH_ENCODER = 'encoder_chk_pth'
Code Example #15
def test_verify_filter(caplog):
    test_verify_root = os.path.join(TESTS_DIR, "test_verify_files", "tt")
    # caplog.set_level sets on root logger by default which is somehow not the logger setup by
    # checksum_helper so specify our logger in the kw param
    caplog.set_level(logging.INFO, logger='Checksum_Helper')
    # ------------ 3 wrong crc, no missing, MixedAlgo ----------
    root_dir = test_verify_root
    a = Args(root_dir=root_dir,
             discover_hash_files_depth=1,
             hash_filename_filter=(),
             filter=[
                 f"sub1{os.sep}*",
                 "new ?.txt",
             ])

    caplog.clear()
    _cl_verify_filter(a)
    assert x_contains_all_y(caplog.record_tuples, [
        ('Checksum_Helper', logging.INFO, f'new 2.txt: MD5 OK'),
        ('Checksum_Helper', logging.WARNING, f'new 3.txt: SHA512 FAILED'),
        ('Checksum_Helper', logging.INFO, f'new 4.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 2.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 3.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 4.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub1{os.sep}sub2{os.sep}new 2.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub1{os.sep}sub2{os.sep}new 3.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING,
         f'sub1{os.sep}sub2{os.sep}new 4.txt: SHA512 FAILED'),
        ('Checksum_Helper', logging.WARNING,
         f'sub1{os.sep}new_cshd_3.txt: SHA512 FAILED'),
    ])

    assert caplog.record_tuples[10:] == [
        ('Checksum_Helper', logging.WARNING,
         f"{root_dir}{os.sep}tt_most_current_{time.strftime('%Y-%m-%d')}.cshd: 3 files with wrong CRCs!"
         ),
        ('Checksum_Helper', logging.INFO,
         f"{root_dir}{os.sep}tt_most_current_{time.strftime('%Y-%m-%d')}.cshd: No missing files!"
         ),
    ]

    # ------------ 1 crc err, 2 missing, MixedAlgo ----------
    root_dir = test_verify_root
    a = Args(root_dir=root_dir,
             discover_hash_files_depth=-1,
             hash_filename_filter=(),
             filter=[
                 "*file?.txt",
                 f"s*{os.sep}sub1{os.sep}**",
             ])

    caplog.clear()
    _cl_verify_filter(a)
    assert x_contains_all_y(caplog.record_tuples, [
        ('Checksum_Helper', logging.WARNING,
         "Found reference beyond the hash file's root dir in file: '%s'. "
         "Consider moving/copying the file using ChecksumHelper move/copy "
         "to the path that is the most common denominator!" %
         os.path.join(root_dir, "sub3", "sub2", "sub3_sub2.sha512")),
        ('Checksum_Helper', logging.INFO, f'sub3{os.sep}file1.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING,
         f'sub3{os.sep}sub1{os.sep}file1.txt: SHA512 FAILED'),
        ('Checksum_Helper', logging.INFO,
         f'sub3{os.sep}sub2{os.sep}file1.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub2{os.sep}sub1{os.sep}file1.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING,
         f'sub5{os.sep}sub1{os.sep}file1.txt: MISSING'),
        ('Checksum_Helper', logging.WARNING,
         f'sub6{os.sep}file1.txt: MISSING'),
    ])

    assert caplog.record_tuples[7:] == [
        ('Checksum_Helper', logging.WARNING,
         f"{root_dir}{os.sep}tt_most_current_{time.strftime('%Y-%m-%d')}.cshd: 1 files with wrong CRCs!"
         ),
        ('Checksum_Helper', logging.WARNING,
         f"{root_dir}{os.sep}tt_most_current_{time.strftime('%Y-%m-%d')}.cshd: 2 missing files!"
         ),
    ]

    # ------------ 2 wrong crc, 4 missing, HashFile, md5 filtered  ----------
    root_dir = test_verify_root
    # hash_filename_filter literally only filters out the hashfile if a str of
    # hash_filename_filter is in the name of the file without the extension
    a = Args(root_dir=root_dir,
             discover_hash_files_depth=-1,
             hash_filename_filter=("*.md5", ),
             filter=[
                 "",
                 f"sub?{os.sep}*",
             ])

    caplog.clear()
    _cl_verify_filter(a)
    assert x_contains_all_y(caplog.record_tuples, [
        ('Checksum_Helper', logging.WARNING,
         "Found reference beyond the hash file's root dir in file: '%s'. "
         "Consider moving/copying the file using ChecksumHelper move/copy "
         "to the path that is the most common denominator!" %
         os.path.join(root_dir, "sub3", "sub2", "sub3_sub2.sha512")),
        ('Checksum_Helper', logging.INFO, f'sub3{os.sep}file1.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING,
         f'sub3{os.sep}sub1{os.sep}file1.txt: SHA512 FAILED'),
        ('Checksum_Helper', logging.INFO,
         f'sub3{os.sep}sub2{os.sep}file1.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 2.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 3.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 4.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub1{os.sep}sub2{os.sep}new 2.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub1{os.sep}sub2{os.sep}new 3.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub1{os.sep}sub2{os.sep}new 4.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub2{os.sep}sub1{os.sep}file1.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING,
         f'sub5{os.sep}sub1{os.sep}file1.txt: MISSING'),
        ('Checksum_Helper', logging.WARNING,
         f'sub6{os.sep}file1.txt: MISSING'),
        ('Checksum_Helper', logging.WARNING,
         f'sub1{os.sep}sub2{os.sep}new 8.txt: MISSING'),
        ('Checksum_Helper', logging.WARNING,
         f'sub1{os.sep}new_cshd_3.txt: SHA512 FAILED'),
        ('Checksum_Helper', logging.WARNING,
         f'sub2{os.sep}new_cshd_missing.txt: MISSING'),
        ('Checksum_Helper', logging.INFO,
         f'sub3{os.sep}sub2{os.sep}new_cshd2.txt: SHA512 OK'),
    ])

    assert caplog.record_tuples[17:] == [
        ('Checksum_Helper', logging.WARNING,
         f"{root_dir}{os.sep}tt_most_current_{time.strftime('%Y-%m-%d')}.cshd: 2 files with wrong CRCs!"
         ),
        ('Checksum_Helper', logging.WARNING,
         f"{root_dir}{os.sep}tt_most_current_{time.strftime('%Y-%m-%d')}.cshd: 4 missing files!"
         ),
    ]

    # ------------ 1 wrong crc, 1 missing, HashFile, md5 filtered  ----------
    root_dir = test_verify_root
    # hash_filename_filter literally only filters out the hashfile if a str of
    # hash_filename_filter is in the name of the file without the extension
    a = Args(root_dir=root_dir,
             discover_hash_files_depth=-1,
             hash_filename_filter=("*.md5", ),
             filter=[
                 "*new* ?.txt",
             ])

    caplog.clear()
    _cl_verify_filter(a)
    assert x_contains_all_y(caplog.record_tuples, [
        ('Checksum_Helper', logging.WARNING,
         "Found reference beyond the hash file's root dir in file: '%s'. "
         "Consider moving/copying the file using ChecksumHelper move/copy "
         "to the path that is the most common denominator!" %
         os.path.join(root_dir, "sub3", "sub2", "sub3_sub2.sha512")),
        ('Checksum_Helper', logging.WARNING, f'new 3.txt: SHA512 FAILED'),
        ('Checksum_Helper', logging.INFO, f'new 4.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 2.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 3.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 4.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub1{os.sep}sub2{os.sep}new 2.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub1{os.sep}sub2{os.sep}new 3.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub1{os.sep}sub2{os.sep}new 4.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING,
         f'sub1{os.sep}sub2{os.sep}new 8.txt: MISSING'),
    ])

    assert caplog.record_tuples[10:] == [
        ('Checksum_Helper', logging.WARNING,
         f"{root_dir}{os.sep}tt_most_current_{time.strftime('%Y-%m-%d')}.cshd: 1 files with wrong CRCs!"
         ),
        ('Checksum_Helper', logging.WARNING,
         f"{root_dir}{os.sep}tt_most_current_{time.strftime('%Y-%m-%d')}.cshd: 1 missing files!"
         ),
    ]
Code Example #16
File: test_mksummery.py Project: labinxu/wcode
 def testDocx(self):
     mksummery.INPUT_ARGS = Args()
     util_docx.write('../out/testdocx.doc')
Code Example #17
def test_verify_hfile(caplog):
    test_verify_root = os.path.join(TESTS_DIR, "test_verify_files", "tt")
    # caplog.set_level sets on root logger by default which is somehow not the logger setup by
    # checksum_helper so specify our logger in the kw param
    caplog.set_level(logging.INFO, logger='Checksum_Helper')
    # ------------ 1 wrong crc, no missing ----------
    hfile_path = os.path.join(test_verify_root, "sub3", "sub2",
                              "sub3_sub2.sha512")
    a = Args(hash_file_name=[hfile_path])
    starting_cwd = os.getcwd()

    caplog.clear()
    # files_total, nr_matches, nr_missing, nr_crc_errors
    assert _cl_verify_hfile(a) == (3, 2, 0, 1)
    # cwd hasn't changed
    assert starting_cwd == os.getcwd()
    assert x_contains_all_y(caplog.record_tuples, [
        ('Checksum_Helper', logging.WARNING,
         "Found reference beyond the hash file's root dir in file: '%s'. "
         "Consider moving/copying the file using ChecksumHelper move/copy "
         "to the path that is the most common denominator!" %
         os.path.join(test_verify_root, "sub3", "sub2", "sub3_sub2.sha512")),
        ('Checksum_Helper', logging.INFO, f'..{os.sep}file1.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING,
         f'..{os.sep}sub1{os.sep}file1.txt: SHA512 FAILED'),
        ('Checksum_Helper', logging.INFO, f'file1.txt: SHA512 OK'),
    ])
    assert caplog.record_tuples[4:] == [
        ('Checksum_Helper', logging.WARNING,
         f'{test_verify_root}{os.sep}sub3{os.sep}sub2{os.sep}sub3_sub2.sha512: 1 files with wrong CRCs!'
         ),
        ('Checksum_Helper', logging.INFO,
         f'{test_verify_root}{os.sep}sub3{os.sep}sub2{os.sep}sub3_sub2.sha512: No missing files!'
         ),
    ]

    # ------------ all matching, 1 missing ----------
    hfile_path = os.path.join(test_verify_root, "sub1", "sub2",
                              "sub2_1miss.sha512")
    a = Args(hash_file_name=[hfile_path])
    starting_cwd = os.getcwd()

    caplog.clear()
    assert _cl_verify_hfile(a) == (3, 2, 1, 0)
    # cwd hasn't changed
    assert starting_cwd == os.getcwd()
    assert x_contains_all_y(caplog.record_tuples, [
        ('Checksum_Helper', logging.INFO, f'new 2.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING, f'new 8.txt: MISSING'),
        ('Checksum_Helper', logging.INFO, f'new 4.txt: SHA512 OK'),
    ])
    assert caplog.record_tuples[3:] == [
        ('Checksum_Helper', logging.INFO,
         f'{test_verify_root}{os.sep}sub1{os.sep}sub2{os.sep}sub2_1miss.sha512: All files matching their hashes!'
         ),
        ('Checksum_Helper', logging.WARNING,
         f'{test_verify_root}{os.sep}sub1{os.sep}sub2{os.sep}sub2_1miss.sha512: 1 missing files!'
         ),
    ]

    # ----------- no missing all matching ----------
    hfile_path = os.path.join(test_verify_root, "sub1", "sub2", "sub2.sha512")
    a = Args(hash_file_name=[hfile_path])
    starting_cwd = os.getcwd()

    caplog.clear()
    # caplog.set_level sets on root logger by default which is somehow not the logger setup by
    # checksum_helper so specify our logger in the kw param
    assert _cl_verify_hfile(a) == (3, 3, 0, 0)
    # cwd hasn't changed
    assert starting_cwd == os.getcwd()
    assert x_contains_all_y(caplog.record_tuples, [
        ('Checksum_Helper', logging.INFO, f'new 2.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'new 3.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'new 4.txt: SHA512 OK'),
    ])
    assert caplog.record_tuples[-1] == (
        'Checksum_Helper', logging.INFO,
        f'{test_verify_root}{os.sep}sub1{os.sep}sub2{os.sep}sub2.sha512: No missing files and all files matching their hashes'
    )

    # ----------- 2 missing 2 crc err ----------
    hfile_path = os.path.join(test_verify_root, "sub1+2_n3+4.sha512")
    a = Args(hash_file_name=[hfile_path])
    starting_cwd = os.getcwd()

    caplog.clear()
    # caplog.set_level sets on root logger by default which is somehow not the logger setup by
    # checksum_helper so specify our logger in the kw param
    assert _cl_verify_hfile(a) == (11, 7, 2, 2)
    # cwd hasn't changed
    assert starting_cwd == os.getcwd()
    assert x_contains_all_y(caplog.record_tuples, [
        ('Checksum_Helper', logging.WARNING, f'new 3.txt: SHA512 FAILED'),
        ('Checksum_Helper', logging.INFO, f'new 4.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 2.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 3.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO, f'sub1{os.sep}new 4.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub1{os.sep}sub2{os.sep}new 2.txt: SHA512 OK'),
        ('Checksum_Helper', logging.INFO,
         f'sub1{os.sep}sub2{os.sep}new 3.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING,
         f'sub1{os.sep}sub2{os.sep}new 4.txt: SHA512 FAILED'),
        ('Checksum_Helper', logging.INFO,
         f'sub2{os.sep}sub1{os.sep}file1.txt: SHA512 OK'),
        ('Checksum_Helper', logging.WARNING,
         f'sub5{os.sep}sub1{os.sep}file1.txt: MISSING'),
        ('Checksum_Helper', logging.WARNING,
         f'sub6{os.sep}file1.txt: MISSING'),
    ])
    assert caplog.record_tuples[11:] == [
        ('Checksum_Helper', logging.WARNING,
         f'{test_verify_root}{os.sep}sub1+2_n3+4.sha512: 2 files with wrong CRCs!'
         ),
        ('Checksum_Helper', logging.WARNING,
         f'{test_verify_root}{os.sep}sub1+2_n3+4.sha512: 2 missing files!'),
    ]
Code Example #18
            pickle.dump([self.words2ids, self.docs2ids, \
                         self.ids2docs, self.dft, self.matrix], f)

    @property
    def N(self):
        """ Number of documents. """
        return len(self.ids2docs)


if __name__ == '__main__':
    import argparse
    import os
    from utils import Args
    parser = argparse.ArgumentParser()
    parser.add_argument('--r',
                        default='custom',
                        type=str,
                        help='id of the run')
    parser.add_argument('--pc',
                        type=int,
                        default=None,
                        help='__ for development: pc ID')
    args_inline = parser.parse_args()

    # all arguments for the run
    args = Args(args_inline)

    # im = IncidenceMatrix.load(args)
    im = IncidenceMatrix.create(args)
    im.save(args)
Code Example #19
from torch.utils.data import DataLoader
import torch
import torch.nn as nn

import numpy as np
from sklearn.utils import shuffle
from utils import Args, EarlyStopping_unlearning
from losses.confusion_loss import confusion_loss
from losses.DANN_loss import DANN_loss_three_classes
import torch.optim as optim
from train_utils import train_unlearn_threedatasets, val_unlearn_threedatasets, train_encoder_unlearn_threedatasets, val_encoder_unlearn_threedatasets
import sys

########################################################################################################################
# Create an args class
args = Args()
args.channels_first = True
args.epochs = 300
args.batch_size = 16
args.diff_model_flag = False
args.alpha = 1
args.patience = 50
args.learning_rate = 1e-4
args.beta = 10
args.epoch_stage_1 = 100
args.epoch_reached = 1

LOAD_PATH_ENCODER = None
LOAD_PATH_REGRESSOR = None
LOAD_PATH_DOMAIN = None
Code Example #20
if __name__ == '__main__':
    import argparse
    from logger import *
    import subprocess
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', default='documents_cs.lst', type=str)
    parser.add_argument('-q', type=str, default='topics-train_cs.xml', help='Topics list file')
    parser.add_argument('-r', type=str, default='run-1_cs', help='Run name')
    parser.add_argument('-document_path', type=str, default='documents_cs', help='Run name')
    parser.add_argument('-f', type=int)
    parser.add_argument('-t', type=int)
    args_inline = parser.parse_args()


    args = Args(args_inline)
    # args._data['o'] = 'results.nosync/res_{}-{}-{}.res'.format(args.lang, args.r, str(9999999))
    if args.lang == 'english':
        qrels = 'A1/qrels-train_en.txt'
    else:
        qrels = 'A1/qrels-train_cs.txt'


    retrieval = Retrieval(args)
    retrieval.search(B=B, K1=K1)

    # THIS CODE IS ONLY FOR TRAINING

    tags = ['cs', 'lemma2', 'qexp3']
    # parameter search
    for experiment in range(args.f, args.t):
Code Example #21
    import argparse
    import os
    import subprocess
    from logger import *
    import regex as re
    parser = argparse.ArgumentParser()
    parser.add_argument('--pc', type=int, default=None, help='__ for development: pc ID')
    parser.add_argument('--start', type=int)
    parser.add_argument('--end', type=int)
    parser.add_argument('--tag', type=str, default=None, help='tag')
    parser.add_argument('-r', type=str, default='run-1_cs', help='Run name')
    parser.add_argument('-lang', type=str, default='english', help='Language')
    args_inline = parser.parse_args()

    # all arguments for the run
    args = Args(args_inline, False, path=None)

    run_im_format = 'x'
    im_file = 'nic'
    tags = [args.lang] if args.tag is None else [args.lang, args.tag]
    for im_version in range(args_inline.start, args_inline.end+1):
        run_format = im_version
        for arg_file_path in os.listdir(os.path.join('args.nosync', str(im_version))):
            path = os.path.join('args.nosync', str(im_version), arg_file_path)
            args = Args(args_inline, True, path=path)
            logger = NeptuneLogger.new_experiment(tags, args)
            # logger.log_hyperparams(args)
            logger.log_status('started')
            with open(path, 'r') as f:
                logger.log_text('args', f.read())
Code Example #22
                    action="store_true",
                    help="test on a smaller dataset (first 1000 samples)")

line_args = parser.parse_args()
if line_args.stop_at < line_args.n_folds:
    print(
        "*****************\n\n\tWARNING:\n\n\tthe code is going to run on {} out of {} folds!\n\n*****************"
        .format(line_args.stop_at, line_args.n_folds))

small = line_args.small
# All important paths and constants

folds_number = line_args.n_folds
stop_after = line_args.stop_at

args = Args()
args.do_train = line_args.train
args.do_eval = line_args.test
args.do_results = line_args.results
args.use_cuda = not line_args.nocuda
args.small = line_args.small
# aggregation_level = "Chapter"
# aggregation_level = "Block"
# aggregation_level = "Category"
# aggregation_level = "Leaf"
aggregation_level = line_args.level

main_dir = "/mnt/HDD/bportelli/lab_avanzato"

original_data_path = "/mnt/HDD/bportelli/lab_avanzato/beatrice.pkl"
Code Example #23
import json
import sys
from vae import Vae
from utils import get_tokenizer, Args, get_dataloader, get_logger, get_callbacks
import pytorch_lightning as pl
import torch

if __name__ == "__main__":

    with open(sys.argv[1]) as f:
        config = json.load(f)
    args = Args(**config)
    args.gpus = torch.cuda.device_count()
    args.distributed_backend = 'ddp' if args.gpus > 1 else None

    tokenizer = get_tokenizer(args.vocab_type)
    args.vocab_size = tokenizer.vocab_size

    train_dl = get_dataloader(args, tokenizer, type='train')
    model = Vae(args)

    logger = get_logger(args)

    callbacks = get_callbacks(args)
    trainer = pl.Trainer(max_steps=args.max_steps,
                         gpus=args.gpus,
                         logger=logger,
                         log_every_n_steps=args.log_steps,
                         callbacks=callbacks,
                         distributed_backend=args.distributed_backend)
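    # not in the original excerpt: the usual next step with this Trainer setup
    trainer.fit(model, train_dl)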