Example #1
0
# Command-line options. `parser` and `check_size` are defined outside this
# view; check_size(lo, hi) is presumably an argparse action factory that
# enforces bounds — TODO confirm how it treats str values such as --gpu_id.
parser.add_argument('--data_path', help='enter path for training data',
                    type=str)
# NOTE(review): --data_path has neither a default nor required=True, so
# args.data_path is None when the flag is omitted — confirm downstream
# code handles that (or make the option required).

parser.add_argument('--gpu_id', default="0", help='enter gpu id',
                    type=str, action=check_size(0, 10))

parser.add_argument('--max_para_req', default=100,
                    help='enter the max length of paragraph',
                    type=int, action=check_size(30, 300))

parser.add_argument('--batch_size_squad', default=16,
                    help='enter the batch size',
                    type=int, action=check_size(1, 256))

# Called without keyword arguments, so on a stock argparse parser this
# registers nothing; kept in case `parser` is a subclass that overrides it.
parser.set_defaults()

args = parser.parse_args()
# Restrict the process to the requested GPU(s). CUDA reads this variable
# when it initialises, so it has to be set before any GPU work starts.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

# Model / training hyper-parameters.
hidden_size = 150
gradient_clip_value = 15
embed_size = 300

# Parameters handed to the model. 'embed_size' reuses the `embed_size`
# constant above rather than repeating the literal 300, so the two values
# cannot silently drift apart. Key order matches the original assignments.
params_dict = {
    'batch_size': args.batch_size_squad,
    'embed_size': embed_size,
    'pad_idx': 0,  # presumably the token id used for padding — confirm
    'hs': hidden_size,
    'glove_dim': 300,  # presumably the GloVe vector dimension — confirm
    'iter_interval': 8000,
    'num_iterations': 500000,
    'ax': ax,  # `ax` is defined outside this view
}
Example #2
0
    }
    return reduced_results


if __name__ == "__main__":
    parser = NgraphArgparser(
        description='Train deep residual network on cifar10 dataset')
    parser.add_argument(
        '--stage_depth',
        type=int,
        default=2,
        help='depth of each stage (network depth will be 9n+2)')
    parser.add_argument('--use_aeon',
                        action='store_true',
                        help='whether to use aeon dataloader')
    args = parser.parse_args()

    np.random.seed(args.rng_seed)

    # Create the dataloader
    if args.use_aeon:
        from data import make_aeon_loaders
        train_set, valid_set = make_aeon_loaders(args.data_dir,
                                                 args.batch_size,
                                                 args.num_iterations)
    else:
        from ngraph.frontends.neon import ArrayIterator  # noqa
        from ngraph.frontends.neon import CIFAR10  # noqa
        train_data, valid_data = CIFAR10(args.data_dir).load_data()
        train_set = ArrayIterator(train_data,
                                  args.batch_size,
Example #3
0
        environment = DimShuffleWrapper(environment)

    # todo: perhaps these should be defined in the environment itself
    state_axes = ng.make_axes([
        ng.make_axis(environment.observation_space.shape[0], name='C'),
        ng.make_axis(environment.observation_space.shape[1], name='H'),
        ng.make_axis(environment.observation_space.shape[2], name='W'),
    ])

    agent = dqn.Agent(
        state_axes,
        environment.action_space,
        model=model,
        epsilon=dqn.linear_generator(start=1.0, end=0.1, steps=1000000),
        gamma=0.99,
        learning_rate=0.00025,
        memory=dqn.Memory(maxlen=1000000),
        target_network_update_frequency=1000,
        learning_starts=10000,
    )

    rl_loop.rl_loop_train(environment, agent, episodes=200000)


if __name__ == "__main__":
    # Script entry point. NgraphArgparser is imported inside the guard, so
    # merely importing this module does not import it.
    from ngraph.frontends.neon import NgraphArgparser

    parser = NgraphArgparser()
    # The returned namespace is discarded and main() is called with no
    # arguments, so parsed option values are not passed to main() explicitly.
    # The call still processes the command line (presumably NgraphArgparser
    # is an argparse.ArgumentParser subclass, so --help / invalid options are
    # handled here).
    # NOTE(review): if main() is meant to honour CLI options (e.g. a seed or
    # batch size), capture the result and pass it through — confirm intent.
    parser.parse_args()
    main()