Example #1
0
    def __init__(self, args, device, writer=None, writer_counter=None, win_event_counter=None):
        """Set up the environment and select its reward function.

        Args:
            args: configuration namespace. This constructor reads
                ``args.reward_function`` and, for the Hough-circle rewards,
                ``args.data_shape``.
            device: stored on ``self.device``.
            writer: optional writer, stored on ``self.writer``.
            writer_counter: optional counter, stored on ``self.writer_counter``.
            win_event_counter: optional counter, stored on
                ``self.win_event_counter``.

        Any ``args.reward_function`` value not matched below falls back to
        ``UnSupervisedReward``.
        """
        super(SpGcnEnv, self).__init__()

        # NOTE(review): reset() runs before args/device/writer are assigned,
        # so it must not rely on them — confirm against reset().
        self.reset()
        self.args = args
        self.device = device
        self.writer = writer
        self.writer_counter = writer_counter
        # Keep the counter instead of silently discarding the argument.
        self.win_event_counter = win_event_counter
        self.discrete_action_space = False

        reward_name = self.args.reward_function
        if reward_name == 'fully_supervised':
            self.reward_function = FullySupervisedReward(env=self)
        elif reward_name == 'sub_graph_dice':
            self.reward_function = SubGraphDiceReward(env=self)
        elif reward_name in ('defining_rules', 'defining_rules_lg'):
            # Both Hough variants share the same parameters; the radius bounds
            # are derived from the largest entry of args.data_shape.
            max_extent = max(self.args.data_shape)
            hough_cls = HoughCircles if reward_name == 'defining_rules' else HoughCircles_lg
            self.reward_function = hough_cls(env=self,
                                             range_num=[8, 10],
                                             range_rad=[max_extent // 18, max_extent // 15],
                                             min_hough_confidence=0.7)
        else:
            self.reward_function = UnSupervisedReward(env=self)
Example #2
0
    def __init__(self,
                 args,
                 device,
                 writer=None,
                 writer_counter=None,
                 win_event_counter=None):
        """Initialise the environment and choose the reward named in ``args``.

        ``stop_quality`` is zeroed before ``reset()`` runs. The caller-supplied
        handles are attached afterwards. The reward object is built from
        ``args.reward_function``, falling back to ``UnSupervisedReward`` for
        any unrecognised name.
        """
        super(SpGcnEnv, self).__init__()
        self.stop_quality = 0

        self.reset()
        # Attach caller-supplied handles (assigned left to right).
        self.args, self.device = args, device
        self.writer, self.writer_counter, self.win_event_counter = (
            writer, writer_counter, win_event_counter)
        self.discrete_action_space = False

        # Resolve the reward class first, then construct it in one place.
        choice = self.args.reward_function
        if choice == 'fully_supervised':
            reward_cls = FullySupervisedReward
        elif choice == 'object_level':
            reward_cls = ObjectLevelReward
        elif choice == 'graph_dice':
            reward_cls = GraphDiceReward
        elif choice == 'focal':
            reward_cls = FocalReward
        elif choice == 'global_sparse':
            reward_cls = GlobalSparseReward
        else:
            reward_cls = UnSupervisedReward
        self.reward_function = reward_cls(env=self)
Example #3
0
    def __init__(self, embedding_net, cfg, device, writer=None, writer_counter=None):
        """Build the node-based embedding-space environment.

        The embedding network is attached before ``reset()`` runs; the config,
        device and writer handles are attached afterwards. The constructor
        also creates a stride-1 3x3 max-pool, a temporal sine step encoder
        sized from ``cfg``, the reward object, and the clustering policy.
        """
        super(EmbeddingSpaceEnvNodeBased, self).__init__()
        self.embedding_net = embedding_net
        self.reset()

        # Caller-supplied handles (assigned left to right).
        self.cfg, self.device = cfg, device
        self.writer, self.writer_counter = writer, writer_counter

        self.last_final_reward = torch.tensor([0.0])
        # kernel 3, stride 1, padding 1: output spatial size equals input size.
        self.max_p = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.step_encoder = TemporalSineEncoding(
            max_step=cfg.trn.max_episode_length,
            size=cfg.fe.n_embedding_features,
        )

        # NOTE(review): SubGraphDiceReward is built without env=self here,
        # unlike the fallback — confirm against its constructor signature.
        use_dice = self.cfg.trn.reward_function == 'sub_graph_dice'
        self.reward_function = SubGraphDiceReward() if use_dice else UnSupervisedReward(env=self)

        self.cluster_policy = nagglo.cosineDistNodeAndEdgeWeightedClusterPolicy
Example #4
0
    def __init__(self, cfg, device, writer=None, writer_counter=None):
        """Set up the environment and select its reward from ``cfg.sac``.

        Args:
            cfg: configuration object. This constructor reads
                ``cfg.sac.reward_function`` and, for the Hough-circle rewards,
                ``cfg.sac.data_shape``.
            device: stored on ``self.device``.
            writer: optional writer, stored on ``self.writer``.
            writer_counter: optional counter, stored on ``self.writer_counter``.

        Raises:
            NotImplementedError: if ``cfg.sac.reward_function`` is
                ``'defining_rules_lg'``, which this environment does not support.

        Any other unrecognised reward name falls back to ``UnSupervisedReward``.
        """
        super(SpGcnEnv, self).__init__()

        # NOTE(review): reset() runs before cfg/device/writer are assigned,
        # so it must not rely on them — confirm against reset().
        self.reset()
        self.cfg = cfg
        self.device = device
        self.writer = writer
        self.writer_counter = writer_counter
        self.discrete_action_space = False
        # kernel 3, stride 1, padding 1: output spatial size equals input size.
        self.max_p = torch.nn.MaxPool2d(3, padding=1, stride=1)

        sac_cfg = self.cfg.sac
        reward_name = sac_cfg.reward_function
        if reward_name == 'fully_supervised':
            self.reward_function = FullySupervisedReward(env=self)
        elif reward_name == 'sub_graph_dice':
            self.reward_function = SubGraphDiceReward(env=self)
        elif reward_name in ('defining_rules_edge_based', 'defining_rules_sp_based'):
            # Both Hough variants share the same parameters; the radius bounds
            # are derived from the largest entry of cfg.sac.data_shape.
            max_extent = max(sac_cfg.data_shape)
            hough_cls = (HoughCircles if reward_name == 'defining_rules_edge_based'
                         else HoughCirclesOnSp)
            self.reward_function = hough_cls(
                env=self,
                range_num=[8, 10],
                range_rad=[max_extent // 18, max_extent // 15],
                min_hough_confidence=0.7)
        elif reward_name == 'defining_rules_lg':
            # Previously `assert False`, which `python -O` strips, silently
            # leaving self.reward_function unset; fail explicitly instead.
            raise NotImplementedError(
                "reward_function 'defining_rules_lg' is not supported by this environment")
        else:
            self.reward_function = UnSupervisedReward(env=self)