Example #1
0
# Command-line interface for the A3C evaluation script.
parser = argparse.ArgumentParser(description='A3C')

# (flag, default, help) triples — registered in a single pass below.
_CLI_OPTIONS = (
    ('--env-name', 'SpaceInvaders-v0',
     'environment to train on (default: SpaceInvaders-v0)'),
    ('--logs-path', "./log.txt",
     'path to logs (default: `./log.txt`)'),
    ('--pretrained-weights', None,
     'path to pretrained weights (default: None – evaluate random model)'),
)
for _flag, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, default=_default, help=_help)

if __name__ == '__main__':
    cmd_args = parser.parse_args()
    # Evaluation results are written both to stdout and to this log file.
    logging.basicConfig(filename=cmd_args.logs_path, level=logging.INFO)

    # clip=False: evaluate on raw (unclipped) rewards so the reported
    # score is the actual game score.
    env = make_atari(cmd_args.env_name, clip=False)
    model = ActorCritic(env.observation_space.shape[0], env.action_space.n)
    if cmd_args.pretrained_weights is not None:
        model.load_weights(cmd_args.pretrained_weights)
    else:
        # Fixed typo in the user-facing message: "weigths" -> "weights".
        print(
            "You have not specified path to model weights, random plays will be performed"
        )
    model.eval()
    # record_video presumably rolls out episodes and records them;
    # returns a summary used below — TODO confirm against its definition.
    results = record_video(model, env)
    log_message = "evaluated on pretrained weights: {}, results: {}".format(
        cmd_args.pretrained_weights, results)
    print(log_message)
    logging.info(log_message)
Example #2
0
    default=None,
    help=
    'path to pretrained weights and optimizer params (default: if None – train from scratch)'
)

if __name__ == '__main__':
    cmd_args = parser.parse_args()
    # YAML config supplies the training defaults; any CLI flags that were
    # actually passed override the corresponding config fields.
    config = Config.fromYamlFile('config.yaml')
    args = config.train
    args.__dict__.update(vars(cmd_args))

    env = make_atari(args.env_name)

    # Shared A3C model: moved into shared memory so every worker process
    # reads and updates the same parameters.
    shared_model = ActorCritic(env.observation_space.shape, env.action_space.n)
    if args.pretrained_weights is not None:
        shared_model.load_weights(args.pretrained_weights)
    shared_model.share_memory()

    optimizer = SharedAdam(shared_model.parameters(), lr=args.learning_rate)
    if args.pretrained_weights is not None:
        # Optimizer state is looked up by rewriting the weights path —
        # assumes checkpoints live under 'weights/' with a sibling
        # 'optimizer_params/' directory; TODO confirm against the saver.
        optimizer.load_params(
            args.pretrained_weights.replace('weights/', 'optimizer_params/'))
    optimizer.share_memory()

    processes = []

    # Cross-process lock and shared integer step counter ('i' = C int).
    lock = mp.Lock()
    total_steps = Value('i', 0)

    # Evaluation/monitoring worker process; presumably started and joined
    # further down (the remainder of this block is outside this view).
    p = mp.Process(target=test_worker,
                   args=(args, shared_model, total_steps, optimizer))