Example No. 1
def get_player(directory=None,
               files_list=None,
               viz=False,
               task='play',
               saveGif=False,
               saveVideo=False,
               agents=2,
               reward_strategy=1):
    # in atari paper, max_num_frames = 30000
    env = MedicalPlayer(directory=directory,
                        screen_dims=IMAGE_SIZE,
                        viz=viz,
                        saveGif=saveGif,
                        saveVideo=saveVideo,
                        task=task,
                        files_list=files_list,
                        agents=agents,
                        max_num_frames=1500,
                        reward_strategy=reward_strategy)
    if task != 'train':
        # in training, env will be decorated by ExpReplay, and history
        # is taken care of in expreplay buffer
        # otherwise, FrameStack modifies self.step to save observations into a queue
        env = FrameStack(env, FRAME_HISTORY, agents=agents)
    return env
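
For context, a minimal usage sketch at evaluation time follows. It assumes a pretrained predictor pred and the play_n_episodes helper that appear in Example No. 10; the paths, argument values, and the exact play_n_episodes signature are illustrative assumptions, not taken from this example.

# Hypothetical evaluation-time usage; paths and values are placeholders.
eval_env = get_player(directory='./data',
                      files_list=['images.txt'],
                      task='eval',
                      agents=2)
# pred is assumed to be a pretrained predictor (e.g. the OfflinePredictor
# built in Example No. 10); the episode count is arbitrary.
play_n_episodes(eval_env, pred, 10)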
Example No. 2
def get_player(directory=None,
               files_list=None,
               landmark_ids=None,
               viz=False,
               task="play",
               file_type="brain",
               saveGif=False,
               saveVideo=False,
               multiscale=True,
               history_length=20,
               agents=1,
               logger=None):
    env = MedicalPlayer(directory=directory,
                        screen_dims=IMAGE_SIZE,
                        viz=viz,
                        saveGif=saveGif,
                        saveVideo=saveVideo,
                        task=task,
                        files_list=files_list,
                        file_type=file_type,
                        landmark_ids=landmark_ids,
                        history_length=history_length,
                        multiscale=multiscale,
                        agents=agents,
                        logger=logger)
    if task != "train":
        # in training, env will be decorated by ExpReplay, and history
        # is taken care of in expreplay buffer
        # otherwise, FrameStack modifies self.step to save observations into a
        # queue
        env = FrameStack(env, FRAME_HISTORY, agents)
    return env
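
By contrast, a hypothetical training-time call under the same assumptions: with task='train' the function returns the bare MedicalPlayer, because (per the comment above) frame history is maintained by the experience-replay buffer rather than by FrameStack. The file names, landmark ids, and logger below are placeholders (the logger is assumed to be constructed as in Example No. 5).

# Hypothetical training-time usage; file names and landmark ids are placeholders.
train_env = get_player(files_list=['images.txt', 'landmarks.txt'],
                       landmark_ids=[0, 1],
                       file_type='brain',
                       task='train',
                       agents=1,
                       logger=logger)
# No FrameStack wrapper is applied here; the replay buffer is expected to
# stack FRAME_HISTORY frames itself during training.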
Example No. 3
def get_player(
    directory=None,
    files_list=None,
    viz=False,
    task="play",
    saveGif=False,
    saveVideo=False,
    agents=2,
    fiducials=None,
    infDir="../inference",
):
    # in atari paper, max_num_frames = 30000
    env = MedicalPlayer(
        directory=directory,
        screen_dims=IMAGE_SIZE,
        viz=viz,
        saveGif=saveGif,
        saveVideo=saveVideo,
        task=task,
        files_list=files_list,
        agents=agents,
        max_num_frames=1500,
        fiducials=fiducials,
        infDir=infDir,
    )
    if task != "train":
        # in training, env will be decorated by ExpReplay, and history
        # is taken care of in expreplay buffer
        # otherwise, FrameStack modifies self.step to save observations into a queue
        env = FrameStack(env, FRAME_HISTORY, agents=agents)
    return env
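
This variant additionally forwards fiducial ids and an inference output directory to MedicalPlayer. A hypothetical inference-time call, with all argument values below being placeholders:

# Hypothetical inference usage; fiducial ids and the output directory are placeholders.
env = get_player(files_list=['images.txt'],
                 task='eval',
                 agents=2,
                 fiducials=[0, 1],
                 infDir='../inference')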
Example No. 4
    if args.task == 'play':
        error_message = """Wrong input files {} for {} task - should be 1 \'images.txt\' """.format(
            len(args.files), args.task)
        assert len(args.files) == 1, error_message
    else:
        error_message = """Wrong input files {} for {} task - should be 2 [\'images.txt\', \'landmarks.txt\'] """.format(
            len(args.files), args.task)
        assert len(args.files) == 2, (error_message)

    args.agents = int(args.agents)

    METHOD = args.algo
    # load files into env to set num_actions, num_validation_files
    init_player = MedicalPlayer(files_list=args.files,
                                screen_dims=IMAGE_SIZE,
                                task='train',
                                agents=args.agents,
                                reward_strategy=args.reward_strategy)
    NUM_ACTIONS = init_player.action_space.n
    num_files = init_player.files.num_files

    ##########################################################
    # initialize states and Qvalues for the various agents
    state_names = []
    qvalue_names = []
    for i in range(0, args.agents):
        state_names.append('state_{}'.format(i))
        qvalue_names.append('Qvalue_{}'.format(i))

    ############################################################
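
For illustration, with args.agents == 2 the loop above simply yields per-agent tensor names. An equivalent comprehension-style sketch (illustrative only, not part of the original code):

# Equivalent sketch using list comprehensions; 2 is a placeholder for args.agents.
agents = 2
state_names = ['state_{}'.format(i) for i in range(agents)]    # ['state_0', 'state_1']
qvalue_names = ['Qvalue_{}'.format(i) for i in range(agents)]  # ['Qvalue_0', 'Qvalue_1']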
Example No. 5
        assert len(args.files) == 1, (error_message)
    else:
        error_message = f"""Wrong input files {len(args.files)} for
                            {args.task} task - should be 2 [\'images.txt\',
                            \'landmarks.txt\'] """
        assert len(args.files) == 2, (error_message)

    logger = Logger(args.logDir, args.write, args.save_freq)

    # load files into env to set num_actions, num_validation_files
    # TODO: is this necessary?
    init_player = MedicalPlayer(
        files_list=args.files,
        file_type=args.file_type,
        landmark_ids=args.landmarks,
        screen_dims=IMAGE_SIZE,
        # TODO: why is this always play?
        task='play',
        agents=agents,
        logger=logger)
    NUM_ACTIONS = init_player.action_space.n

    if args.task != 'train':
        # TODO: refactor DQN to not have to create both a q_network and
        # target_network
        dqn = DQN(agents,
                  frame_history=FRAME_HISTORY,
                  logger=logger,
                  type=args.model_name)
        model = dqn.q_network
        model.load_state_dict(torch.load(args.load, map_location=model.device))
Example No. 6
        )
        assert len(args.files) == 1, error_message
    else:
        error_message = """Wrong input files {} for {} task - should be 2 [\'images.txt\', \'landmarks.txt\'] """.format(
            len(args.files), args.task
        )
        assert len(args.files) == 2, error_message

    # args.agents=int(args.agents)

    METHOD = args.algo
    # load files into env to set num_actions, num_validation_files
    init_player = MedicalPlayer(
        files_list=args.files,
        screen_dims=IMAGE_SIZE,
        task=args.task,
        agents=args.agents,
        fiducials=args.fiducials,
        infDir=args.inferDir,
    )
    NUM_ACTIONS = init_player.action_space.n
    num_files = init_player.files.num_files

    ##########################################################
    # initialize states and Qvalues for the various agents
    state_names = []
    qvalue_names = []
    for i in range(0, args.agents):
        state_names.append("state_{}".format(i))
        qvalue_names.append("Qvalue_{}".format(i))

    ############################################################
Example No. 7
    # check input files
    if args.task == 'play':
        error_message = """Wrong input files {} for {} task - should be 1 \'images.txt\' """.format(
            len(args.files), args.task)
        assert len(args.files) == 1, (error_message)
    else:
        error_message = """Wrong input files {} for {} task - should be 2 [\'images.txt\', \'landmarks.txt\'] """.format(
            len(args.files), args.task)
        assert len(args.files) == 2, (error_message)

    METHOD = args.algo
    # load files into env to set num_actions, num_validation_files
    init_player = MedicalPlayer(
        files_list=args.files,
        screen_dims=IMAGE_SIZE,
        # TODO: why is this always play?
        task='play',
        agents=args.agents,
        reward_strategy=args.reward_strategy)
    NUM_ACTIONS = init_player.action_space.n
    num_files = init_player.files.num_files

    if args.task != 'train':
        assert args.load is not None
        pred = OfflinePredictor(
            PredictConfig(model=Model(),
                          session_init=get_model_loader(args.load),
                          input_names=['state'],
                          output_names=['Qvalue']))
        # demo pretrained model one episode at a time
        if args.task == 'play' or args.task == 'eval':
Example No. 8
    # check input files
    if args.task == 'play':
        error_message = """Wrong input files {} for {} task - should be 1 \'images.txt\' """.format(
            len(args.files), args.task)
        assert len(args.files) == 1, error_message
    else:  # this is 'train'
        print("args.task:", args.task)
        error_message = """Wrong input files {} for {} task - should be 2 [\'images.txt\', \'landmarks.txt\'] """.format(
            len(args.files), args.task)
        assert len(args.files) == 2, (error_message)

    METHOD = args.algo

    # load files into env to set num_actions, num_validation_files
    init_player = MedicalPlayer(files_list=args.files,
                                screen_dims=IMAGE_SIZE,
                                task='play')

    NUM_ACTIONS = init_player.action_space.n
    num_files = init_player.files.num_files

    if args.task != 'train':
        assert args.load is not None
        pred = OfflinePredictor(
            PredictConfig(model=Model(),
                          session_init=get_model_loader(args.load),
                          input_names=['state'],
                          output_names=['Qvalue']))

        # demo pretrained model one episode at a time
        if args.task == 'play':
Example No. 9
    # check input files
    if args.task == "play":
        error_message = """Wrong input files {} for {} task - should be 1 \'images.txt\' """.format(
            len(args.files), args.task)
        assert len(args.files) == 1, error_message
    else:
        error_message = """Wrong input files {} for {} task - should be 2 [\'images.txt\', \'landmarks.txt\'] """.format(
            len(args.files), args.task)
        assert len(args.files) == 2, error_message

    args.agents = int(args.agents)

    METHOD = args.algo
    # load files into env to set num_actions, num_validation_files
    init_player = MedicalPlayer(files_list=args.files,
                                screen_dims=IMAGE_SIZE,
                                task="train",
                                agents=args.agents)
    NUM_ACTIONS = init_player.action_space.n
    num_files = init_player.files.num_files

    ##########################################################
    # initialize states and Qvalues for the various agents
    state_names = []
    qvalue_names = []
    for i in range(0, args.agents):
        state_names.append("state_{}".format(i))
        qvalue_names.append("Qvalue_{}".format(i))

    ############################################################

    if args.task != "train":
Example No. 10
                        help='save gif image of the game',
                        action='store_true',
                        default=False)
    parser.add_argument('--saveVideo',
                        help='save video of the game',
                        action='store_true',
                        default=False)
    args = parser.parse_args()

    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    METHOD = args.algo
    # load files into env to set num_actions, num_validation_files
    init_player = MedicalPlayer(directory=data_dir,
                                files_list=test_list,
                                screen_dims=IMAGE_SIZE)
    NUM_ACTIONS = init_player.action_space.n
    num_files = init_player.files.num_files

    if args.task != 'train':
        assert args.load is not None
        pred = OfflinePredictor(
            PredictConfig(model=Model(),
                          session_init=get_model_loader(args.load),
                          input_names=['state'],
                          output_names=['Qvalue']))
        # demo pretrained model one episode at a time
        if args.task == 'play':
            play_n_episodes(
                get_player(directory=data_dir,