Code Example #1
def run():
    set_logger()
    args = get_args()
    set_seed(args.seed)

    preprocessor = Preprocessor.load(args.preprocessor)

    valid_data_set = DataSet(is_training=False, preprocessor=preprocessor)
    valid_data_set.input_data(args.inference_data)
    valid_data_loader = DataLoader(valid_data_set, args.mb_size)

    # Set up Model
    with open(args.model_params, 'r') as f:
        model_params = json.load(f)
    model = get_model(model_params)
    model.load_state_dict(torch.load(args.model_weight))

    decoding_params = get_decoding_params(args, preprocessor)

    inference_encoder_decoder(
        inference_data_loader=valid_data_loader,
        model=model,
        preprocessor=preprocessor,
        decoding_params=decoding_params,
        inference_csv_path=args.inference_csv,
    )
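
Every example in this collection opens with the same `set_logger()` / `get_args()` / `set_seed(args.seed)` preamble, whose helpers are not shown. For reference, a `set_seed` in a PyTorch project is typically just a few lines; a minimal sketch, assuming Python's `random`, NumPy, and PyTorch are the RNGs to pin down:

import random

import numpy as np
import torch

def set_seed(seed: int) -> None:
    # Hypothetical helper: seed every common RNG so runs are reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # No-op when CUDA is unavailable.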
Code Example #2
def run():
    set_logger()
    args = get_args()
    set_seed(args.seed)

    logger.info('Load Wikipedia articles in Japanese.')

    data_set = WikipediaDataSet(args.input_dir, args.cache_dir)

    logger.info('Train gensim Word2Vec model.')
    # `size` and `iter` are gensim < 4.0 keyword names; see the version note
    # after this example for the gensim 4.x equivalents.
    w2v_model = Word2Vec(window=args.window,
                         size=args.size,
                         negative=args.negative,
                         ns_exponent=args.ns_exponent,
                         min_count=args.min_count,
                         alpha=args.alpha,
                         min_alpha=args.min_alpha,
                         iter=args.epochs,
                         workers=args.workers,
                         seed=args.seed)
    texts = list(data_set.get_text())  # Materialize the corpus once, not twice.
    w2v_model.build_vocab(texts)
    w2v_model.train(texts,
                    total_examples=len(data_set),
                    epochs=args.epochs)

    if args.model_name_to_save:
        logger.info('Save gensim Word2Vec model.')
        with open(args.model_name_to_save, 'wb') as f:
            dill.dump(w2v_model, f)
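
A version note: `size` and `iter` above are gensim 3.x keyword arguments; gensim 4.0 renamed them to `vector_size` and `epochs`. On a current gensim, the constructor call would change along these lines:

    w2v_model = Word2Vec(window=args.window,
                         vector_size=args.size,   # gensim >= 4.0 name for `size`
                         negative=args.negative,
                         ns_exponent=args.ns_exponent,
                         min_count=args.min_count,
                         alpha=args.alpha,
                         min_alpha=args.min_alpha,
                         epochs=args.epochs,      # gensim >= 4.0 name for `iter`
                         workers=args.workers,
                         seed=args.seed)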
Code Example #3
def run():
    set_logger()
    args = get_args()
    set_seed(args.seed)

    train_data_set = Cifar10DataSet(args.train_image_npy_path, args.train_label_npy_path)
    train_data_loader = Cifar10DataLoader(train_data_set, args.mb_size, use_augment=True)
    test_data_set = Cifar10DataSet(args.test_image_npy_path, args.test_label_npy_path)
    test_data_loader = Cifar10DataLoader(test_data_set, args.mb_size)

    model_params = get_model_params(args)
    optimizer_params = get_optimizer_params(args)
    lr_scheduler_params = get_lr_scheduler_params(args, train_data_loader)

    output_dir_path = args.output_dir_format.format(date=get_date_str())
    setup_output_dir(output_dir_path, dict(args._get_kwargs()), model_params, optimizer_params) #pylint: disable=protected-access

    # Set up Model and Optimizer
    model = get_model(model_params)
    optimizer = get_torch_optimizer(model.parameters(), optimizer_params)
    lr_scheduler = get_torch_lr_scheduler(optimizer, lr_scheduler_params)

    train_loop(
        train_data_loader=train_data_loader,
        valid_data_loader=test_data_loader,
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        epochs=args.epochs,
        output_dir_path=output_dir_path,
        model_name_format=args.model_name_format,
    )
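
`train_loop` itself is defined elsewhere in the project. As a rough guide to what a loop with this signature typically does, here is a minimal sketch; the cross-entropy loss, device handling, per-step scheduler, and the `epoch` placeholder in `model_name_format` are assumptions for illustration, not the project's actual code:

import os

import torch
import torch.nn.functional as F

def train_loop(train_data_loader, valid_data_loader, model, optimizer,
               lr_scheduler, epochs, output_dir_path, model_name_format):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    for epoch in range(1, epochs + 1):
        model.train()
        for images, labels in train_data_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = F.cross_entropy(model(images), labels)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()  # Assumes a per-step scheduler, as the loader-aware params suggest.

        # Evaluate on the validation set without tracking gradients.
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in valid_data_loader:
                images, labels = images.to(device), labels.to(device)
                correct += (model(images).argmax(dim=1) == labels).sum().item()
                total += labels.size(0)
        print(f'epoch {epoch}: valid accuracy {correct / total:.4f}')

        # Checkpoint each epoch with the caller-supplied name format.
        torch.save(model.state_dict(),
                   os.path.join(output_dir_path, model_name_format.format(epoch=epoch)))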
Code Example #4
def run():
    set_logger()
    args = get_args()
    set_seed(args.seed)

    train_data_set = DataSet(is_training=True)
    train_data_set.input_data(args.train_data)
    train_data_loader = DataLoader(train_data_set, args.mb_size)
    preprocessor = train_data_set.preprocessor

    valid_data_set = DataSet(is_training=False, preprocessor=preprocessor)
    valid_data_set.input_data(args.valid_data)
    valid_data_loader = DataLoader(valid_data_set, args.mb_size)

    model_params = get_model_params(args)
    optimizer_params = get_optimizer_params(args)
    lr_scheduler_params = get_lr_scheduler_params(args, train_data_loader)
    decoding_params = get_decoding_params(args, preprocessor)
    if args.lang == 'ja_to_en':
        model_params['encoder_vocab_count'] = train_data_set.ja_vocab_count
        model_params['decoder_vocab_count'] = train_data_set.en_vocab_count
    elif args.lang == 'en_to_ja':
        model_params['encoder_vocab_count'] = train_data_set.en_vocab_count
        model_params['decoder_vocab_count'] = train_data_set.ja_vocab_count

    output_dir_path = args.output_dir_format.format(date=get_date_str())
    setup_output_dir(
        output_dir_path, dict(args._get_kwargs()), #pylint: disable=protected-access
        model_params, optimizer_params, decoding_params)
    preprocessor.save(os.path.join(output_dir_path, args.preprocessor))

    # Set up Model and Optimizer
    if args.model_params:
        # Note: parameters loaded from this JSON file replace the model_params
        # assembled above, including the vocab counts set per language direction.
        with open(args.model_params, 'r') as f:
            model_params = json.load(f)
    model = get_model(model_params)
    if args.initial_weight:
        model.load_state_dict(torch.load(args.initial_weight))
    optimizer = get_torch_optimizer(model.parameters(), optimizer_params)
    lr_scheduler = get_torch_lr_scheduler(optimizer, lr_scheduler_params)

    train_loop(
        train_data_loader=train_data_loader,
        valid_data_loader=valid_data_loader,
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        preprocessor=preprocessor,
        decoding_params=decoding_params,
        epochs=args.epochs,
        output_dir_path=output_dir_path,
        model_name_format=args.model_name_format,
    )
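
The preprocessor saved above is the artifact Code Example #1 reloads with `Preprocessor.load` before inference. The project's implementation is not shown here; a minimal sketch of such a save/load pair, assuming dill serialization as the Word2Vec examples use:

import dill

class Preprocessor:
    # ... vocabulary and tokenization state elided ...

    def save(self, path: str) -> None:
        # Serialize the whole object so inference can restore identical preprocessing.
        with open(path, 'wb') as f:
            dill.dump(self, f)

    @classmethod
    def load(cls, path: str) -> 'Preprocessor':
        with open(path, 'rb') as f:
            return dill.load(f)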
Code Example #5
def run():
    set_logger()
    args = get_args()
    set_seed(args.seed)

    figure_path = os.path.join(args.output_dir, args.time_step_plot)
    setup_output_dir(args.output_dir, dict(args._get_kwargs()))  # pylint: disable=protected-access

    env = MazeEnvironment(args.maze)
    if args.algorithm == 'valueiter':
        value_iteration(
            env,
            gamma=args.gamma,
            iter_count=args.iter_count,
            render=args.render,
            figure_path=figure_path,
        )
    elif args.algorithm == 'sarsalambda':
        q_value = env.get_initial_q_value()
        # Warm up: learn with epsilon-greedy exploration first, then run the
        # learned policy greedily (epsilon=0) below.
        sarsa_lambda(
            q_value,
            env,
            alpha=args.alpha,
            gamma=args.gamma,
            epsilon=args.epsilon,
            lambda_value=args.lambda_value,
            iter_count=10,
            render=False,
            figure_path=figure_path,
        )
        sarsa_lambda(
            q_value,
            env,
            alpha=args.alpha,
            gamma=args.gamma,
            epsilon=0,
            lambda_value=args.lambda_value,
            iter_count=100,
            render=args.render,
            figure_path=figure_path,
        )
    env.close()
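
`sarsa_lambda` is invoked twice on the same Q table: a short exploratory warm-up, then a longer greedy run. Its inner update is the standard eligibility-trace backup; a minimal sketch with hypothetical NumPy arrays `q_value` and `trace` indexed by `(state, action)`:

import numpy as np

def sarsa_lambda_step(q_value, trace, state, action, reward,
                      next_state, next_action, alpha, gamma, lambda_value):
    # One SARSA(lambda) backup with accumulating eligibility traces (sketch).
    td_error = reward + gamma * q_value[next_state, next_action] - q_value[state, action]
    trace[state, action] += 1.0          # Accumulate the trace for the visited pair.
    q_value += alpha * td_error * trace  # Credit all recently visited pairs.
    trace *= gamma * lambda_value        # Decay every trace toward zero.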
Code Example #6
File: train_my_w2v.py  Project: Yuki-Wada/mltools
def run():
    set_logger()
    args = get_args()
    set_seed(args.seed)

    logger.info('Load Wikipedia articles in Japanese.')

    data_set = WikipediaDataSet(args.input_dir, args.cache_dir)
    data_loader = WikipediaDataLoader(data_set, args.mb_size * args.workers)

    dictionary: Dictionary = data_set.dictionary
    dictionary.filter_extremes(no_below=args.min_count, no_above=0.999)
    w2v_model = MyWord2Vec(dictionary=dictionary,
                           window=args.window,
                           size=args.size,
                           negative=args.negative,
                           ns_exponent=args.ns_exponent,
                           alpha=args.alpha,
                           workers=args.workers)

    logger.info('Train my Word2Vec model.')
    for epoch in range(args.epochs):
        logger.info('Epoch: %d', epoch + 1)

        w2v_model.wc = 0  # Reset the model's per-epoch word counter.
        with tqdm(total=len(data_set), desc="Train Word2Vec") as pbar:
            for mb_texts in data_loader.get_iter():
                mb_indexed_texts = [
                    dictionary.doc2idx(text) for text in mb_texts
                ]
                w2v_model.train(mb_indexed_texts)

                # Linearly decay the learning rate from alpha toward min_alpha
                # as training progresses (see the note after this example).
                w2v_model.lr = \
                    args.alpha - ((args.alpha - args.min_alpha) * (epoch + 1) / args.epochs)
                pbar.update(len(mb_indexed_texts))

        if args.model_name_to_save:
            logger.info('Save my Word2Vec model.')
            with open(args.model_name_to_save, 'wb') as f:
                dill.dump(w2v_model, f)
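
The learning-rate assignment above is a linear schedule: after epoch `e` (1-indexed) the rate is `alpha - (alpha - min_alpha) * e / epochs`, falling to exactly `min_alpha` at the final epoch. A quick endpoint check with made-up values:

alpha, min_alpha, epochs = 0.025, 0.0001, 10  # Illustrative values, not project defaults.
for epoch in range(epochs):
    lr = alpha - (alpha - min_alpha) * (epoch + 1) / epochs
    print(f'after epoch {epoch + 1}: lr = {lr:.5f}')
# The final iteration prints lr = 0.00010, i.e. exactly min_alpha.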
Code Example #7
def run():
    set_logger()
    args = get_args()
    set_seed(args.seed)

    figure_path = os.path.join(args.output_dir, args.time_step_plot)
    setup_output_dir(args.output_dir, dict(args._get_kwargs()))  #pylint: disable=protected-access

    state_converter = CartPoleStateConverter(epsilon=args.epsilon)

    gym.envs.registration.register(
        id='CartPole-v2',
        entry_point='gym.envs.classic_control:CartPoleEnv',
        max_episode_steps=args.max_steps,
        reward_threshold=int(args.max_steps * 0.975),
    )

    env = gym.make('CartPole-v2')
    if args.algorithm == 'montecarlo':
        monte_carlo(
            env,
            state_converter,
            max_steps=args.max_steps,
            first_visit=args.first_visit,
            gamma=args.gamma,
            episode_count=args.episode_count,
            render=args.render,
            figure_path=figure_path,
        )
    elif args.algorithm == 'sarsa':
        sarsa(
            env,
            state_converter,
            max_steps=args.max_steps,
            n_step=args.n_step,
            alpha=args.alpha,
            gamma=args.gamma,
            episode_count=args.episode_count,
            render=args.render,
            figure_path=figure_path,
        )
    elif args.algorithm == 'qlearning':
        q_learning(
            env,
            state_converter,
            max_steps=args.max_steps,
            n_step=args.n_step,
            alpha=args.alpha,
            gamma=args.gamma,
            episode_count=args.episode_count,
            render=args.render,
            figure_path=figure_path,
        )
    elif args.algorithm == 'dynaq':
        dyna_q(
            env,
            state_converter,
            max_steps=args.max_steps,
            n_step=args.n_step,
            alpha=args.alpha,
            gamma=args.gamma,
            episode_count=args.episode_count,
            render=args.render,
            figure_path=figure_path,
        )
    elif args.algorithm == 'prioritized':
        prioritized_sweeping(
            env,
            state_converter,
            max_steps=args.max_steps,
            n_step=args.n_step,
            alpha=args.alpha,
            gamma=args.gamma,
            episode_count=args.episode_count,
            render=args.render,
            figure_path=figure_path,
        )
    elif args.algorithm == 'sarsalambda':
        sarsa_lambda(
            env,
            state_converter,
            max_steps=args.max_steps,
            alpha=args.alpha,
            gamma=args.gamma,
            lambda_value=args.lambda_value,
            episode_count=args.episode_count,
            render=args.render,
            figure_path=figure_path,
        )
    env.close()
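
Each branch here is a tabular method, so `CartPoleStateConverter` presumably discretizes CartPole's continuous observation into table indices. For reference, the one-step backup at the core of the `qlearning` branch looks like this; the `q_value` array layout is an assumption for illustration:

import numpy as np

def q_learning_step(q_value, state, action, reward, next_state, alpha, gamma):
    # One-step Q-learning backup on a tabular q_value array (sketch).
    best_next = np.max(q_value[next_state])  # Off-policy: bootstrap from the greedy action.
    q_value[state, action] += alpha * (reward + gamma * best_next - q_value[state, action])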