コード例 #1
0
ファイル: train.py プロジェクト: zxxia/PCC-RL
def main():
    """Entry point: parse CLI args, build a Pensieve agent, and train it.

    Reads (at least) model_path, save_dir, seed, jump_action, curriculum,
    config_file, train_trace_dir, val_trace_dir, real_trace_prob, nagent,
    total_epoch and video_size_file_dir from ``parse_args()``.

    Raises:
        ValueError: if ``model_path`` is given but is not a ``.ckpt`` file.
        NotImplementedError: for curricula other than ``"udr"``.
    """
    args = parse_args()
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, so validation would silently disappear.
    if args.model_path and not args.model_path.endswith(".ckpt"):
        raise ValueError("model_path must point to a .ckpt file")
    os.makedirs(args.save_dir, exist_ok=True)
    save_args(args, args.save_dir)
    set_seed(args.seed)

    # Initialize model and agent policy.
    if args.jump_action:
        # NOTE(review): 6, 6, 3 look like model dimensions for the
        # jump-action variant -- confirm against the Pensieve constructor.
        pensieve = Pensieve(args.model_path, 6, 6, 3)
    else:
        pensieve = Pensieve(args.model_path)

    training_traces = []
    val_traces = []
    if args.curriculum == "udr":
        config_file = args.config_file
        if args.train_trace_dir:
            all_time, all_bw, all_file_names = load_traces(
                args.train_trace_dir)
            training_traces = [
                AbrTrace(t, bw, link_rtt=80, buffer_thresh=60, name=name)
                for t, bw, name in zip(all_time, all_bw, all_file_names)
            ]

        if args.val_trace_dir:
            all_time, all_bw, all_file_names = load_traces(args.val_trace_dir)
            val_traces = [
                AbrTrace(t, bw, link_rtt=80, buffer_thresh=60, name=name)
                for t, bw, name in zip(all_time, all_bw, all_file_names)
            ]
        train_scheduler = UDRTrainScheduler(
            config_file,
            training_traces,
            percent=args.real_trace_prob,
        )
    elif args.curriculum == "cl1":
        # Curriculum-learning schedulers are not wired up for Pensieve yet.
        raise NotImplementedError
    elif args.curriculum == "cl2":
        raise NotImplementedError
    else:
        raise NotImplementedError

    pensieve.train(train_scheduler, val_traces, args.save_dir, args.nagent,
                   args.total_epoch, args.video_size_file_dir)
コード例 #2
0
def main():
    """Entry point: parse args, build an MPI-aware Aurora agent, train it.

    Each MPI rank derives its own seed (base seed + rank * 100) so workers
    draw different random streams.

    Raises:
        ValueError: if ``pretrained_model_path`` is not a ``.ckpt`` file,
            or if ``dataset`` is neither ``'pantheon'`` nor ``'synthetic'``.
    """
    args = parse_args()
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, so validation would silently disappear.
    if (args.pretrained_model_path is not None
            and not args.pretrained_model_path.endswith(".ckpt")):
        raise ValueError("pretrained_model_path must point to a .ckpt file")
    os.makedirs(args.save_dir, exist_ok=True)
    save_args(args, args.save_dir)
    # Offset the seed per MPI rank so workers are decorrelated.
    seed = args.seed + COMM_WORLD.Get_rank() * 100
    set_seed(seed)
    nprocs = COMM_WORLD.Get_size()

    # Initialize model and agent policy. 7200 is split across workers;
    # presumably a validation/checkpoint interval -- TODO confirm against
    # the Aurora constructor.
    aurora = Aurora(seed,
                    args.save_dir,
                    int(7200 / nprocs),
                    args.pretrained_model_path,
                    tensorboard_log=args.tensorboard_log)
    training_traces = []
    val_traces = []
    if args.train_trace_file:
        # One trace file path per line.
        with open(args.train_trace_file, 'r') as f:
            for line in f:
                training_traces.append(Trace.load_from_file(line.strip()))

    if args.val_trace_file:
        with open(args.val_trace_file, 'r') as f:
            for line in f:
                line = line.strip()
                if args.dataset == 'pantheon':
                    queue = 100  # dummy value
                    val_traces.append(
                        Trace.load_from_pantheon_file(line,
                                                      queue=queue,
                                                      loss=0))
                elif args.dataset == 'synthetic':
                    val_traces.append(Trace.load_from_file(line))
                else:
                    raise ValueError

    aurora.train(args.randomization_range_file,
                 args.total_timesteps,
                 tot_trace_cnt=args.total_trace_count,
                 tb_log_name=args.exp_name,
                 validation_flag=args.validation,
                 training_traces=training_traces,
                 validation_traces=val_traces,
                 real_trace_prob=args.real_trace_prob)
コード例 #3
0
    def __init__(self):
        # Intentionally not delegating to common_init() for the moment.
        args = self.init_args()
        self.args = args

        if args.continue_train and args.model_dir is None:
            raise Exception("'--model-dir' must be specified when using "
                            "'--continue-train'")

        # Prepare output directories and logging before anything else.
        prepare_dir(args)
        self.logger = get_logger(args)
        set_utils_logger(self.logger)

        # Seed every RNG in use so runs are reproducible.
        seed = args.seed
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

        # Device selection, then persist run metadata (args + git commit).
        init_device(args)
        save_args(args)
        save_commit_id(args)
        self.tb = TensorBoard(args.model_dir)
コード例 #4
0
def main():
    """Entry point: parse args, build an MPI-aware Aurora agent, and train
    it under the curriculum selected by ``args.curriculum``.

    Each MPI rank derives its own seed (base seed + rank * 100) so workers
    draw different random streams; the validation frequency is divided
    evenly across ranks.

    Raises:
        ValueError: if ``pretrained_model_path`` is not a ``.ckpt`` file,
            or if ``dataset`` is neither ``"pantheon"`` nor ``"synthetic"``.
        NotImplementedError: for an unrecognized curriculum.
    """
    args = parse_args()
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, so validation would silently disappear.
    if (args.pretrained_model_path
            and not args.pretrained_model_path.endswith(".ckpt")):
        raise ValueError("pretrained_model_path must point to a .ckpt file")
    os.makedirs(args.save_dir, exist_ok=True)
    save_args(args, args.save_dir)
    # Offset the seed per MPI rank so workers are decorrelated; reuse the
    # same seed for the agent instead of recomputing the expression.
    seed = args.seed + COMM_WORLD.Get_rank() * 100
    set_seed(seed)
    nprocs = COMM_WORLD.Get_size()

    # Initialize model and agent policy.
    aurora = Aurora(
        seed,
        args.save_dir,
        int(args.val_freq / nprocs),
        args.pretrained_model_path,
        tensorboard_log=args.tensorboard_log,
    )
    training_traces = []
    val_traces = []
    if args.curriculum == "udr":
        config_file = args.config_file
        if args.train_trace_file:
            # One trace file path per line.
            with open(args.train_trace_file, "r") as f:
                for line in f:
                    training_traces.append(Trace.load_from_file(line.strip()))

        if args.validation and args.val_trace_file:
            with open(args.val_trace_file, "r") as f:
                for line in f:
                    line = line.strip()
                    if args.dataset == "pantheon":
                        queue = 100  # dummy value
                        val_traces.append(
                            Trace.load_from_pantheon_file(line,
                                                          queue=queue,
                                                          loss=0))
                    elif args.dataset == "synthetic":
                        val_traces.append(Trace.load_from_file(line))
                    else:
                        raise ValueError
        train_scheduler = UDRTrainScheduler(
            config_file,
            training_traces,
            percent=args.real_trace_prob,
        )
    elif args.curriculum == "cl1":
        config_file = args.config_files[0]
        train_scheduler = CL1TrainScheduler(args.config_files, aurora)
    elif args.curriculum == "cl2":
        config_file = args.config_file
        train_scheduler = CL2TrainScheduler(config_file, aurora, args.baseline)
    else:
        raise NotImplementedError

    aurora.train(
        config_file,
        args.total_timesteps,
        train_scheduler,
        tb_log_name=args.exp_name,
        validation_traces=val_traces,
    )
コード例 #5
0
def common_init(that):
    """Common initialization of our models. Here is the check list:

        - [√] Parse the input arguments
        - [√] Create necessary folders to save data
        - [√] Set a logger to be used and save the output
        - [√] Set manual seeds to make results reproducible
        - [√] Init the correct device to be used by pytorch: cpu or cuda:id
        - [√] Save the input arguments used
        - [√] Save the git infos: commit id, repo origin
        - [√] Set a tensorboard object to record stats
        - [√] Set a DataSelector object which handles data samples
        - [√] Set a StatKeeper object which can save arbitrary stats
        - [√] Perform specific initializations based on input params
    """
    that.args = that.init_args()

    if that.args.continue_train and that.args.model_dir is None:
        raise Exception("'--model-dir' must be specified when using "
                        "'--continue-train'")

    prepare_dir(that.args)
    that.logger = get_logger(that.args)
    set_utils_logger(that.logger)
    np.random.seed(that.args.seed)
    random.seed(that.args.seed)
    torch.manual_seed(that.args.seed)
    init_device(that.args)
    save_args(that.args)
    save_commit_id(that.args)
    that.tb = TensorBoard(that.args.model_dir)
    that.ds = DataSelector(that.args)
    that.sk = StatsKeeper(that.args, that.args.stat_folder)

    # Init seq. Every variant is spelled "<name>_<overlap>"; map each name
    # to its DataSelector builder instead of repeating the parse/check/build
    # boilerplate six times (the duplication had already produced a bug: the
    # overlapCNF error message was missing its f-string prefix, so {overlap}
    # was printed literally).
    overlap_builders = {
        "overlap": that.ds.overlap_seq,
        "overlapC": that.ds.overlap_c_seq,
        "overlapCN": that.ds.overlap_cn_seq,
        "overlapCNX": that.ds.overlap_cnx_seq,
        "overlapCX": that.ds.overlap_cx_seq,
        "overlapCNF": that.ds.overlap_cnf_seq,
    }
    init_seq = that.args.init_seq
    parts = init_seq.split("_")
    if init_seq == "original":
        # Done by default in DataSelector initialization
        pass
    elif len(parts) >= 2 and parts[0] in overlap_builders:
        name = parts[0]
        overlap = int(parts[1])
        if name == "overlapCNF":
            # overlapCNF only bounds the overlap; the others require that
            # it divide the BPTT length exactly.
            if overlap > that.args.bptt:
                raise Exception(
                    f"overlapCNF must be lower than '--bptt' (found {overlap})")
        elif that.args.bptt % overlap != 0:
            raise Exception(f"{name} must divide '--bptt' (found {overlap})")
        that.ds.current_seq = overlap_builders[name](
            that.args.batch_size, overlap)
    else:
        raise Exception(f"init-seq unknown: {init_seq}")

    # Type of train_seq
    if that.args.train_seq == "original":
        that.train_seq = that.ds.train_seq
    elif that.args.train_seq.startswith("repeat_"):
        n = int(that.args.train_seq.split("_")[1])
        # Bind n as a default so the lambda captures the current value.
        that.train_seq = lambda n=n: that.ds.repeated_train_seq(n)
    else:
        raise Exception(f"train-seq unknown: {that.args.train_seq}")

    # Shuffling of the train_seq
    if that.args.shuffle_row_seq:
        that.ds.shuffle_row_train_seq()
    if that.args.shuffle_col_seq:
        that.ds.shuffle_col_train_seq()
    if that.args.shuffle_each_row_seq:
        that.ds.shuffle_each_row_train_seq()
    if that.args.shuffle_full_seq:
        that.ds.shuffle_full_train_seq()