def test_basic(self):
        """Dataclass defaults apply with no argv; flags override them."""
        @dataclass
        class Opt:
            x: int = 42
            y: bool = False

        # No arguments at all: the dataclass defaults must win.
        result = ArgumentParser(Opt).parse_args([])
        self.assertEqual(result.x, 42)
        self.assertEqual(result.y, False)
        # Explicit arguments override both defaults.
        result = ArgumentParser(Opt).parse_args(["--x=10", "--y"])
        self.assertEqual(result.x, 10)
        self.assertEqual(result.y, True)
    def test_nargs_plus(self):
        """A List field with nargs='+' collects one or more values."""
        @dataclass
        class Args:
            name: str
            friends: List[str] = field(metadata=dict(nargs="+"))

        argv = ["--name", "Sam", "--friends", "pippin", "Frodo"]
        parsed = ArgumentParser(Args).parse_args(argv)
        self.assertEqual(parsed.name, "Sam")
        self.assertEqual(parsed.friends, ["pippin", "Frodo"])

        # One more trailing value keeps accumulating into the same list.
        argv.append("Bilbo")
        parsed = ArgumentParser(Args).parse_args(argv)
        self.assertEqual(parsed.name, "Sam")
        self.assertEqual(parsed.friends, ["pippin", "Frodo", "Bilbo"])
    def test_type(self):
        """A metadata 'type' callable converts the raw argument string."""
        @dataclass
        class Options:
            name: str = field(metadata=dict(type=str.title))

        parsed = ArgumentParser(Options).parse_args(["--name", "john doe"])
        self.assertEqual("John Doe", parsed.name)
    def test_choices(self):
        """A value inside the declared choices parses cleanly."""
        @dataclass
        class Options:
            small_integer: int = field(metadata=dict(choices=[1, 2, 3]))

        parsed = ArgumentParser(Options).parse_args(["--small-integer", "2"])
        self.assertEqual(2, parsed.small_integer)
    def test_choices_negative(self):
        """A value outside the declared choices must abort parsing."""
        @dataclass
        class Options:
            small_integer: int = field(metadata=dict(choices=[1, 2, 3]))

        parser = ArgumentParser(Options)
        with NegativeTestHelper() as catcher:
            parser.parse_args(["--small-integer", "20"])
        # A recorded exit status means the parser rejected the value.
        self.assertIsNotNone(catcher.exit_status,
                             "Expected an error while parsing")
    def test_positional(self):
        """Explicit 'args' metadata supports flag aliases and positionals."""
        @dataclass
        class Options:
            x: int = field(metadata=dict(args=["-x", "--long-name"]))
            positional: str = field(metadata=dict(args=["positional"]))

        parsed = ArgumentParser(Options).parse_args(["-x", "0", "POS_VALUE"])
        self.assertEqual(0, parsed.x)
        self.assertEqual("POS_VALUE", parsed.positional)
    def test_no_defaults_negative(self):
        """Omitting fields that have no default must abort parsing."""
        @dataclass
        class Args:
            num_of_foo: int
            name: str

        parser = ArgumentParser(Args)
        with NegativeTestHelper() as catcher:
            parser.parse_args([])
        self.assertIsNotNone(catcher.exit_status,
                             "Expected an error while parsing")
    def test_no_defaults(self):
        """Fields without defaults parse fine once values are supplied."""
        @dataclass
        class Args:
            num_of_foo: int
            name: str

        argv = ["--num-of-foo=10", "--name", "Sam"]
        parsed = ArgumentParser(Args).parse_args(argv)
        self.assertEqual(parsed.num_of_foo, 10)
        self.assertEqual(parsed.name, "Sam")
    def test_default_factory(self):
        """default_factory runs per parse; an explicit value bypasses it."""
        @dataclass
        class Parameters:
            cutoff_date: dt.datetime = field(
                default_factory=dt.datetime.now,
                metadata=dict(type=dt.datetime.fromisoformat))

        # Two argument-less parses: each default must be generated fresh,
        # landing between the surrounding timestamps.
        for _ in range(2):
            before = dt.datetime.now()
            parsed = ArgumentParser(Parameters).parse_args([])
            after = dt.datetime.now()
            self.assertGreaterEqual(parsed.cutoff_date, before)
            self.assertLessEqual(parsed.cutoff_date, after)

        # An explicit ISO value takes precedence over the factory.
        date = dt.datetime(2000, 1, 1)
        parsed = ArgumentParser(Parameters).parse_args(
            ["--cutoff-date", date.isoformat()])
        self.assertEqual(date, parsed.cutoff_date)
    def test_default_factory_2(self):
        """The factory runs only when the CLI omits the option."""
        factory_calls = 0

        def factory_func():
            nonlocal factory_calls
            factory_calls += 1
            return f"Default Message: {factory_calls}"

        @dataclass
        class Parameters:
            message: str = field(default_factory=factory_func)

        # First defaulted parse: the factory fires exactly once.
        parsed = ArgumentParser(Parameters).parse_args([])
        self.assertEqual("Default Message: 1", parsed.message)
        self.assertEqual(1, factory_calls)

        # Explicit value: the factory must not run again.
        parsed = ArgumentParser(Parameters).parse_args(
            ["--message", "User message"])
        self.assertEqual("User message", parsed.message)
        self.assertEqual(1, factory_calls)

        # Second defaulted parse: the factory fires a second time.
        parsed = ArgumentParser(Parameters).parse_args([])
        self.assertEqual("Default Message: 2", parsed.message)
        self.assertEqual(2, factory_calls)
# Example #11
    # NOTE(review): this excerpt begins mid-class — the enclosing class header
    # (presumably a @dataclass of CLI params, given the annotated defaults and
    # the ArgumentParser(params) usage below) is outside this view.
    # Template path: "{ietk_method_name}" is filled in later via .format()
    # over the namespace's __dict__ (see main() below).
    checkpoint_fp: str = './data/results/R2.2-{ietk_method_name}/model_checkpoints/epoch_best.pth'

    #  get dataset: RITE test split; targets 'av' and 'vessel' returned as
    #  numpy arrays (per the getitem_transform arguments).
    dset = D.RITE(use_train_set=False,
                  getitem_transform=D.RITE.as_tensor(['av', 'vessel'],
                                                     return_numpy_array=True))

    # Upper bound on how many images main() processes.
    n_imgs: int = 10


def main(ns: params):
    """Render and save segmentation figures for up to ns.n_imgs images."""
    os.makedirs(ns.img_save_dir, exist_ok=True)
    model = load_model(ns.checkpoint_fp.format(**ns.__dict__), ns.device)
    for idx, img_tuple in enumerate(yield_imgs(model, ns)):
        # One figure per key; save each under a name encoding index and key.
        for key, fig in plot_rite_segmentation(*img_tuple, ns).items():
            fig.savefig(f'{ns.img_save_dir}/rite-seg-{idx}-{key}.png',
                        bbox_inches='tight')
        # TODO(review): iterate through images and show correlation?
        # e.g. perf = (yhat > 0.5 == y).sum((1,2,3)); expect perf[0] <= perf[1]
        if idx >= ns.n_imgs:
            break


if __name__ == "__main__":
    # Build the params namespace from the CLI, echo it, and run the pipeline.
    ns = ArgumentParser(params).parse_args()
    print(ns)
    main(ns)
# Example #12
def start(config_params: Config):
    """Mine a process tree from the configured event log and log the results."""
    started_at = time.time()
    event_log = ImportData(config_params.event_log_file)
    # Echo the imported log's structure for traceability.
    for item in (event_log.trace_list, event_log.unique_events,
                 event_log.event_map):
        logging.info(item)
    seed_population = InitialPopulation(
        event_log.unique_events, config_params.initial_population_size
    )
    best_tree = utility.run(
        seed_population.trees, event_log.unique_events,
        event_log.trace_list, config_params
    )
    logging.info(f"Execution time: {time.time() - started_at}")
    logging.info(
        f"Tree: {best_tree} Replay fitness: {best_tree.replay_fitness} Precision: {best_tree.precision} Simplicity: {best_tree.simplicity} Generalization: {best_tree.generalization} Fitness: {best_tree.fitness}"
    )
    # Dump the effective configuration, then the winning tree's attributes.
    for key, value in vars(config_params).items():
        logging.info(f"{key}: {value}")
    logging.info("Tree class values")
    for key, value in vars(best_tree).items():
        logging.info(f"{key}: {value}")


if __name__ == "__main__":
    # Parse CLI options into a Config instance and start the miner.
    parser = ArgumentParser(Config)
    config = parser.parse_args()
    start(config)



# Example #13 — file: run.py, project: dkmiller/tidbits
    # NOTE(review): excerpt starts mid-function — `args`, `json_files`, and
    # `log` are defined above this view.
    output_dir = Path(args.output_directory)
    output_dir.mkdir(exist_ok=True)
    output_path = output_dir / args.output_file_name
    for file_path in json_files:
        log.info(f"Loading {file_path}")
        with file_path.open("r") as fp:
            lines = fp.readlines()
            log.info(f"Found {len(lines)} lines")
            for line in lines:
                # Each input line is a standalone JSON document (JSONL).
                parsed_line = json.loads(line)
                source_value = select_and_merge_jsonpaths(
                    parsed_line, args.source_jsonpaths)
                target_value = select_and_merge_jsonpaths(
                    parsed_line, [args.target_jsonpath])
                output_object = {
                    args.source_key: source_value,
                    args.target_key: target_value,
                }
                output_line = json.dumps(output_object)
                # NOTE(review): reopening the output file in append mode for
                # every input line is one open() per line — consider opening
                # once outside the loops.
                with output_path.open("a") as out_fp:
                    out_fp.write(output_line)
                    out_fp.write(os.linesep)


if __name__ == "__main__":
    # Parse CLI options into an Args instance, then run the conversion.
    parser = ArgumentParser(Args)
    args = parser.parse_args()
    # NOTE(review): logging is configured only after argument parsing, so any
    # log output emitted during parsing would use the default configuration.
    logging.basicConfig(level="INFO")
    main(args)
# Example #14
    # NOTE(review): excerpt starts mid-function — `coalesced_dataset`,
    # `preprocess_function`, `tokenizer`, `args`, and `train_args` are
    # defined above this view.
    tokenized_dataset = coalesced_dataset.map(preprocess_function,
                                              batched=True)

    # Pads each batch dynamically instead of padding the whole dataset.
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # Two-label classification head on top of the pretrained backbone.
    model = AutoModelForSequenceClassification.from_pretrained(args.model,
                                                               num_labels=2)

    # TODO: separate train and eval inputs.
    trainer = Trainer(
        model=model,
        args=train_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["train"],  # evaluates on the train split (see TODO)
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    trainer.train()


if __name__ == "__main__":
    logging.basicConfig(level="INFO")
    # Two-pass parsing: HfArgumentParser consumes the TrainingArguments flags
    # it recognizes; the leftover strings are handed to the project's Args
    # parser so both dataclasses can share one argv.
    parser = HfArgumentParser(TrainingArguments)
    (train_args, unknown) = parser.parse_args_into_dataclasses(
        return_remaining_strings=True)
    parser = ArgumentParser(Args)
    args = parser.parse_args(unknown)

    main(train_args, args)
# Example #15
class ParseOptions:
    """Training CLI options; each field maps to a single-dash flag via
    the 'args' metadata.

    NOTE(review): the dataclasses.field(...) usage implies this class should
    carry a @dataclass decorator — it may have been lost from this excerpt;
    confirm against the original file.
    """
    # NOTE(review): tuple defaults — confirm how the parser converts a CLI
    # string into Tuple[int, int].
    M1_dim: Tuple[int, int] = field(metadata=dict(args=["-M1_dim"]),
                                    default=(2, 2))
    M2_dim: Tuple[int, int] = field(metadata=dict(args=["-M2_dim"]),
                                    default=(2, 2))
    # Bracketed list string like "[32,64]"; converted to a list of ints in
    # the __main__ block below.
    hidden_layers: str = field(metadata=dict(args=["-hiddens"]), default=None)
    log_dir: str = field(metadata=dict(args=["-log-dir"]), default=None)
    learning_rate: float = field(metadata=dict(args=["-learning_rate"]),
                                 default=1e-3)
    buffer_size: int = field(metadata=dict(args=["-buf_size"]), default=1000)
    batch_size: int = field(metadata=dict(args=["-batch_size"]), default=32)
    loss: str = field(metadata=dict(args=["-loss"]), default="mse")
    optimizer: str = field(metadata=dict(args=["-optimizer"]), default="adam")
    activation: str = field(metadata=dict(args=["-activation"]),
                            default="ReLU")
    layer: str = field(metadata=dict(args=["-layer"]), default="affine")


if __name__ == "__main__":
    parser = ArgumentParser(ParseOptions)
    # Parse argv once and reuse the result (the original called parse_args()
    # twice: once just to print, once to keep).
    args = parser.parse_args()
    print(args)
    # "-hiddens" arrives as a bracketed list string like "[32,64]"; strip the
    # brackets and convert to a list of ints. Guard against the default
    # (None) instead of crashing on the slice when the flag is omitted.
    if args.hidden_layers:
        args.hidden_layers = [
            int(size) for size in args.hidden_layers[1:-1].split(',')
        ]
    # Launch TensorBoard on the configured log dir and open the UI.
    subprocess.Popen(["tensorboard", "--logdir", args.log_dir])
    webbrowser.open("127.0.0.1:6006")
    trainer = Trainer(**vars(args))
    trainer.train()
# Example #16
                                        flip_x=False,
                                        resize_to=img.shape[:2],
                                        crop_to_size=img.shape[:2])
        # NOTE(review): excerpt starts mid-function — the call these keyword
        # arguments belong to, plus `img`, `axs`, `ax_idx`, `fig`, `save_fp`,
        # and `ietk_method_name`, are defined above this view.
        # Convert CHW tensor to HWC uint8 for imshow; assumes values lie in
        # [0, 1] before the *255 scaling — TODO confirm.
        enhanced_img = (enhanced_img.permute(1, 2, 0).numpy() *
                        255).astype('uint8')

        axs[ax_idx].imshow(enhanced_img)
        axs[ax_idx].set_title(ietk_method_name, fontsize=20)
        axs[ax_idx].axis('off')
    # Tighten subplot spacing manually (tight_layout left disabled).
    fig.subplots_adjust(wspace=0.02, hspace=0.02, top=0.9, bottom=0.1)
    #  fig.tight_layout()
    fig.savefig(save_fp, bbox_inches='tight')
    return fig


def main(ns):
    """Save an enhancement plot for each chosen (or randomly sampled) image."""
    os.makedirs(ns.save_fig_dir, exist_ok=True)

    # A specific index pins the selection; otherwise sample ns.num_imgs
    # indices at random from the dataset.
    if ns.img_idx is None:
        chosen = np.random.randint(0, len(ns.dset_qualdr), ns.num_imgs)
    else:
        chosen = [ns.img_idx]
    for idx in chosen:
        save_plot(ns, idx)


if __name__ == "__main__":
    # Parse the params dataclass from the CLI and run the plotting loop.
    main(ArgumentParser(params).parse_args())