Ejemplo n.º 1
0
    def __init__(self, args, shared_damped_model, global_count,
                 global_writer_loss_count, global_writer_quality_count,
                 global_win_event_count, save_dir):
        """Initialize the off-policy actor-critic trainer state.

        Stores the shared counters/model and selects the stop-quality and
        epsilon (exploration) decay schedules from ``args``.

        Args:
            args: Config namespace; read for ``stop_qual_rule``, ``eps_rule``
                and the associated schedule parameters.
            shared_damped_model: Model instance shared across workers.
            global_count: Shared global step counter.
            global_writer_loss_count: Shared counter for loss-writer indices.
            global_writer_quality_count: Shared counter for quality-writer
                indices.
            global_win_event_count: Shared counter of win events.
            save_dir: Directory where results/checkpoints are written.
        """
        super(AgentOffpac, self).__init__()
        self.args = args
        self.shared_damped_model = shared_damped_model
        self.global_count = global_count
        self.global_writer_loss_count = global_writer_loss_count
        self.global_writer_quality_count = global_writer_quality_count
        self.global_win_event_count = global_win_event_count
        self.writer_idx_warmup_loss = 0
        self.save_dir = save_dir

        # Schedule controlling the stopping-quality threshold over training.
        if args.stop_qual_rule == 'naive':
            self.stop_qual_rule = NaiveDecay(initial_eps=args.init_stop_qual,
                                             episode_shrinkage=1,
                                             change_after_n_episodes=5)
        elif args.stop_qual_rule == 'gaussian':
            self.stop_qual_rule = GaussianDecay(args.stop_qual_final,
                                                args.stop_qual_scaling,
                                                args.stop_qual_offset,
                                                args.T_max)
        else:
            self.stop_qual_rule = NaiveDecay(args.init_stop_qual)

        # Schedule controlling the exploration epsilon.
        if self.args.eps_rule == "treesearch":
            self.eps_rule = ActionPathTreeNodes()
        elif self.args.eps_rule == "sawtooth":
            self.eps_rule = ExpSawtoothEpsDecay()
        elif self.args.eps_rule == 'gaussian':
            self.eps_rule = GaussianDecay(args.eps_final, args.eps_scaling,
                                          args.eps_offset, args.T_max)
        else:
            # BUG FIX: the original passed `self.eps`, which is never assigned
            # (its assignment `self.eps = self.args.init_epsilon` was commented
            # out) and would raise AttributeError here. Read the args value
            # directly instead.
            self.eps_rule = NaiveDecay(args.init_epsilon, 0.00005, 1000, 1)
Ejemplo n.º 2
0
    def __init__(self, args, shared_average_model, global_count,
                 global_writer_loss_count, global_writer_quality_count,
                 global_win_event_count, save_dir):
        """Initialize the continuous-action ACER trainer state.

        Stores the shared counters/model and selects the stop-quality and
        behavior-sigma decay schedules from ``args``.

        Args:
            args: Config namespace; read for ``stop_qual_rule``, ``eps_rule``
                and the associated schedule parameters. ``args.T_max`` is
                overwritten with ``np.inf`` for the self-regulating rules.
            shared_average_model: Averaged model shared across workers.
            global_count: Shared global step counter.
            global_writer_loss_count: Shared counter for loss-writer indices.
            global_writer_quality_count: Shared counter for quality-writer
                indices.
            global_win_event_count: Shared counter of win events.
            save_dir: Directory where results/checkpoints are written.
        """
        super(AgentAcerContinuousTrainer, self).__init__()

        self.args = args
        self.shared_average_model = shared_average_model
        self.global_count = global_count
        self.global_writer_loss_count = global_writer_loss_count
        self.global_writer_quality_count = global_writer_quality_count
        self.global_win_event_count = global_win_event_count
        self.writer_idx_warmup_loss = 0
        self.save_dir = save_dir

        # Schedule controlling the stopping-quality threshold over training.
        if args.stop_qual_rule == 'naive':
            self.stop_qual_rule = NaiveDecay(initial_eps=args.init_stop_qual,
                                             episode_shrinkage=1,
                                             change_after_n_episodes=5)
        elif args.stop_qual_rule == 'gaussian':
            self.stop_qual_rule = GaussianDecay(args.stop_qual_final,
                                                args.stop_qual_scaling,
                                                args.stop_qual_offset,
                                                args.T_max)
        elif args.stop_qual_rule == 'running_average':
            self.stop_qual_rule = RunningAverage(
                args.stop_qual_ra_bw,
                args.stop_qual_scaling + args.stop_qual_offset,
                args.stop_qual_ra_off)
        else:
            self.stop_qual_rule = NaiveDecay(args.init_stop_qual)

        # Schedule controlling the behavior-policy sigma (exploration noise).
        if self.args.eps_rule == "treesearch":
            self.b_sigma_rule = ActionPathTreeNodes()
        elif self.args.eps_rule == "sawtooth":
            self.b_sigma_rule = ExpSawtoothEpsDecay()
        elif self.args.eps_rule == 'gaussian':
            self.b_sigma_rule = GaussianDecay(args.b_sigma_final,
                                              args.b_sigma_scaling,
                                              args.p_sigma, args.T_max)
        elif self.args.eps_rule == "self_reg_min":
            # Self-regulating rules run open-ended; disable the step limit.
            self.args.T_max = np.inf
            self.b_sigma_rule = FollowLeadMin(
                (args.stop_qual_scaling + args.stop_qual_offset), 1)
        elif self.args.eps_rule == "self_reg_avg":
            self.args.T_max = np.inf
            self.b_sigma_rule = FollowLeadAvg(
                (args.stop_qual_scaling + args.stop_qual_offset) / 4, 2, 1)
        elif self.args.eps_rule == "self_reg_exp_avg":
            self.args.T_max = np.inf
            self.b_sigma_rule = ExponentialAverage(
                (args.stop_qual_scaling + args.stop_qual_offset) / 4, 0.9, 1)
        else:
            # BUG FIX: the original passed `self.eps`, which is never assigned
            # (its assignment `self.eps = self.args.init_epsilon` was commented
            # out) and would raise AttributeError here. Read the args value
            # directly instead.
            self.b_sigma_rule = NaiveDecay(args.init_epsilon, 0.00005, 1000, 1)
Ejemplo n.º 3
0
    def __init__(self, cfg, args, global_count, global_writer_loss_count,
                 global_writer_quality_count, global_win_event_count,
                 action_stats_count, save_dir):
        """Initialize the SAC trainer state.

        Stores the shared counters and selects the stop-quality schedule and
        the temperature (beta) regulation rule.

        Args:
            cfg: Config object; read for ``temperature_regulation`` and
                ``init_temperature``.
            args: Config namespace; read for ``stop_qual_rule`` and the
                associated schedule parameters.
            global_count: Shared global step counter.
            global_writer_loss_count: Shared counter for loss-writer indices.
            global_writer_quality_count: Shared counter for quality-writer
                indices.
            global_win_event_count: Shared counter of win events.
            action_stats_count: Shared counter of per-action statistics.
            save_dir: Directory where results/checkpoints are written.
        """
        super(AgentSacTrainer_sg_lg, self).__init__()

        self.cfg = cfg
        self.args = args
        self.global_count = global_count
        self.global_writer_loss_count = global_writer_loss_count
        self.global_writer_quality_count = global_writer_quality_count
        self.global_win_event_count = global_win_event_count
        self.action_stats_count = action_stats_count
        self.save_dir = save_dir

        # Schedule controlling the stopping-quality threshold over training.
        if args.stop_qual_rule == 'naive':
            self.stop_qual_rule = NaiveDecay(initial_eps=args.init_stop_qual,
                                             episode_shrinkage=1,
                                             change_after_n_episodes=5)
        elif args.stop_qual_rule == 'gaussian':
            self.stop_qual_rule = GaussianDecay(args.stop_qual_final,
                                                args.stop_qual_scaling,
                                                args.stop_qual_offset,
                                                args.T_max)
        elif args.stop_qual_rule == 'running_average':
            self.stop_qual_rule = RunningAverage(
                args.stop_qual_ra_bw,
                args.stop_qual_scaling + args.stop_qual_offset,
                args.stop_qual_ra_off)
        else:
            self.stop_qual_rule = Constant(args.stop_qual_final)

        # Rule regulating the SAC temperature (beta).
        # NOTE(review): no fallback branch — for any other value of
        # cfg.temperature_regulation, beta_rule stays unset; confirm callers
        # only pass these two values.
        if self.cfg.temperature_regulation == 'follow_quality':
            self.beta_rule = FollowLeadAvg(1, 80, 1)
        elif self.cfg.temperature_regulation == 'constant':
            # BUG FIX: the original assigned `self.eps_rule` here, a
            # copy-paste slip — the sibling branch sets `self.beta_rule`, so
            # under 'constant' the beta_rule attribute was never created.
            self.beta_rule = Constant(cfg.init_temperature)