Example No. 1
def single_eval(args=None):
    if args is None:
        args = command_parser.parse_arguments()

    create_shared_model = model_class(args.model)
    init_agent = agent_class(args.agent_type)

    args.phase = 'eval'
    args.episode_type = 'TestValEpisode'
    args.test_or_val = 'test'

    # if args.num_category != 60:
    #     args.detection_feature_file_name = 'det_feature_{}_categories.hdf5'.format(args.num_category)

    start_time = time.time()
    local_start_time_str = time.strftime('%Y_%m_%d_%H_%M_%S',
                                         time.localtime(start_time))

    tb_log_dir = args.log_dir + "/" + args.title + '_' + args.phase + '_' + local_start_time_str
    log_writer = SummaryWriter(log_dir=tb_log_dir)

    checkpoint = args.load_model

    model = os.path.join(args.save_model_dir, checkpoint)
    args.load_model = model

    # run eval on model
    # args.test_or_val = "val"
    main_eval(args, create_shared_model, init_agent)
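
These snippets are excerpts from a larger project and omit their imports. Below is a hedged sketch of the imports the examples on this page appear to rely on; the project-specific module paths (command_parser/flag_parser, model_class, agent_class, optimizer_class, main_eval) are guesses and are left commented out, and SummaryWriter may come from tensorboardX rather than torch.utils.tensorboard depending on the repo.

import ctypes
import json
import os
import random
import time

import numpy as np
import setproctitle
import torch
import torch.multiprocessing as mp
from tabulate import tabulate
from torch.utils.tensorboard import SummaryWriter  # or: from tensorboardX import SummaryWriter
from tqdm import tqdm

# Project-specific helpers (paths are assumptions; adjust to the actual repo layout):
# from utils import command_parser, flag_parser
# from utils.class_finder import model_class, agent_class, optimizer_class
# from main_eval import main_eval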
Example No. 2
def main():
    args = flag_parser.parse_arguments()

    create_shared_model = model_class(args.model)
    init_agent = agent_class(args.agent_type)

    args.episode_type = "TestValEpisode"
    args.test_or_val = "test"

    # Get all valid saved_models for the given title and sort by train_ep.
    checkpoints = [(f, f.split("_")) for f in os.listdir(args.save_model_dir)]

    checkpoints = [
        (f, int(s[-3])) for (f, s) in checkpoints
        if len(s) >= 4 and f.startswith(args.title) and f.endswith('dat')
    ]
    checkpoints.sort(key=lambda x: x[1])

    best_model_on_val = None
    best_performance_on_val = 0.0
    for (f, train_ep) in tqdm(checkpoints, desc="Checkpoints."):

        model = os.path.join(args.save_model_dir, f)
        args.load_model = model

        # run eval on model
        args.test_or_val = "test"
        main_eval(args, create_shared_model, init_agent)

        # check if best on val.
        with open(args.results_json, "r") as f:
            results = json.load(f)

        if results["success"] > best_performance_on_val:
            best_model_on_val = model
            best_performance_on_val = results["success"]

    args.test_or_val = "test"
    args.load_model = best_model_on_val
    main_eval(args, create_shared_model, init_agent)

    with open(args.results_json, "r") as f:
        results = json.load(f)

    print(
        tabulate(
            [
                ["SPL >= 1:", results["GreaterThan/1/spl"]],
                ["Success >= 1:", results["GreaterThan/1/success"]],
                ["SPL >= 5:", results["GreaterThan/5/spl"]],
                ["Success >= 5:", results["GreaterThan/5/success"]],
            ],
            headers=["Metric", "Result"],
            tablefmt="orgtbl",
        ))

    print("Best model:", args.load_model)
Example No. 3
def main():
    setproctitle.setproctitle("Train/Test Manager")
    args = flag_parser.parse_arguments()

    if args.model == "BaseModel" or args.model == "GCN":
        args.learned_loss = False
        args.num_steps = 50
        target = nonadaptivea3c_val if args.eval else nonadaptivea3c_train
    # else:
    #     args.learned_loss = True
    #     args.num_steps = 6
    #     target = savn_val if args.eval else savn_train

    create_shared_model = model_class(args.model)
    init_agent = agent_class(args.agent_type)
    optimizer_type = optimizer_class(args.optimizer)
    # print('shared model created')
    if args.eval:
        main_eval(args, create_shared_model, init_agent)
        return

    start_time = time.time()
    local_start_time_str = time.strftime("%Y-%m-%d_%H:%M:%S",
                                         time.localtime(start_time))
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # print('seeding done')

    if args.log_dir is not None:
        tb_log_dir = args.log_dir + "/" + args.title + "-" + local_start_time_str
        log_writer = SummaryWriter(log_dir=tb_log_dir)
    else:
        log_writer = SummaryWriter(comment=args.title)

    if args.gpu_ids == -1:
        args.gpu_ids = [-1]
    else:
        # print('something to do with cuda')
        torch.cuda.manual_seed(args.seed)
        mp.set_start_method("spawn")

    shared_model = create_shared_model(args)

    train_total_ep = 0
    n_frames = 0

    if shared_model is not None:
        # print('shared model is being created')
        shared_model.share_memory()
        optimizer = optimizer_type(
            filter(lambda p: p.requires_grad, shared_model.parameters()), args)
        optimizer.share_memory()
        print(shared_model)
        # print('!!!!!!!!!!!!')
    else:
        assert (args.agent_type == "RandomNavigationAgent"
                ), "The model is None but agent is not random agent"
        optimizer = None

    processes = []

    end_flag = mp.Value(ctypes.c_bool, False)

    train_res_queue = mp.Queue()

    for rank in range(0, args.workers):
        print('Process {} being created'.format(rank))
        p = mp.Process(
            target=target,
            args=(
                rank,
                args,
                create_shared_model,
                shared_model,
                init_agent,
                optimizer,
                train_res_queue,
                end_flag,
            ),
        )
        p.start()
        processes.append(p)
        time.sleep(0.1)

    print("Train agents created.")

    train_thin = args.train_thin
    train_scalars = ScalarMeanTracker()

    try:
        while train_total_ep < args.max_ep:
            print('total train ep: {} of {}'.format(train_total_ep,
                                                    args.max_ep))
            print('Cuda available: {}'.format(torch.cuda.is_available()))
            train_result = train_res_queue.get()
            print('Got the train result from the queue')
            train_scalars.add_scalars(train_result)
            train_total_ep += 1
            n_frames += train_result["ep_length"]
            if (train_total_ep % train_thin) == 0:
                log_writer.add_scalar("n_frames", n_frames, train_total_ep)
                tracked_means = train_scalars.pop_and_reset()
                for k in tracked_means:
                    log_writer.add_scalar(k + "/train", tracked_means[k],
                                          train_total_ep)

            if (train_total_ep % args.ep_save_freq) == 0:

                print(n_frames)
                if not os.path.exists(args.save_model_dir):
                    os.makedirs(args.save_model_dir)
                state_to_save = shared_model.state_dict()
                save_path = os.path.join(
                    args.save_model_dir,
                    "{0}_{1}_{2}_{3}.dat".format(args.title, n_frames,
                                                 train_total_ep,
                                                 local_start_time_str),
                )
                torch.save(state_to_save, save_path)

    finally:
        log_writer.close()
        end_flag.value = True
        for p in processes:
            time.sleep(0.1)
            p.join()
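
The training manager above follows a simple producer/consumer layout: each worker process pushes one result dict per episode into a shared mp.Queue, and the main process pops results, aggregates them, and periodically saves the shared model, while end_flag tells the workers when to stop. A minimal, self-contained sketch of that queue-and-flag pattern (the worker below is a stand-in for the real nonadaptivea3c_train target):

import ctypes
import time
import multiprocessing as mp   # torch.multiprocessing exposes the same API


def fake_worker(rank, res_queue, end_flag):
    # stand-in for a training worker: emit one "episode result" per loop
    while not end_flag.value:
        res_queue.put({"ep_length": 10, "success": float(rank % 2)})
        time.sleep(0.05)


if __name__ == "__main__":
    end_flag = mp.Value(ctypes.c_bool, False)
    res_queue = mp.Queue()
    workers = [mp.Process(target=fake_worker, args=(r, res_queue, end_flag))
               for r in range(2)]
    for p in workers:
        p.start()

    n_frames = 0
    for episode in range(10):           # the "while train_total_ep < max_ep" loop
        result = res_queue.get()        # blocks until some worker finishes an episode
        n_frames += result["ep_length"]

    end_flag.value = True               # tell workers to stop producing
    for p in workers:
        p.join()
    print("consumed 10 episodes,", n_frames, "frames")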
Example No. 4
def main():
    setproctitle.setproctitle("Train/Test Manager")
    args = flag_parser.parse_arguments()

    if args.model == "BaseModel" or args.model == "GCN_MLP" or args.model == "GCN" or args.model == "GCN_GRU":
        args.learned_loss = False
        args.num_steps = 50
        target = nonadaptivea3c_val if args.eval else nonadaptivea3c_train

    create_shared_model = model_class(args.model)
    init_agent = agent_class(args.agent_type)
    optimizer_type = optimizer_class(args.optimizer)

    if args.eval:
        main_eval(args, create_shared_model, init_agent)
        return

    model_to_open = args.load_model

    if model_to_open != "":
        shared_model = create_shared_model(args)
        # mirror the fresh-start branch below: parameters must live in shared
        # memory before the worker processes are created
        shared_model.share_memory()
        optimizer = optimizer_type(
            filter(lambda p: p.requires_grad, shared_model.parameters()), args)
        saved_state = torch.load(model_to_open,
                                 map_location=lambda storage, loc: storage)
        shared_model.load_state_dict(saved_state['model'])
        optimizer.load_state_dict(saved_state['optimizer'])
        optimizer.share_memory()
        train_total_ep = saved_state['train_total_ep']
        n_frames = saved_state['n_frames']

    else:
        shared_model = create_shared_model(args)

        train_total_ep = 0
        n_frames = 0

        if shared_model is not None:
            shared_model.share_memory()
            optimizer = optimizer_type(
                filter(lambda p: p.requires_grad, shared_model.parameters()),
                args)
            optimizer.share_memory()
            print(shared_model)
        else:
            assert (args.agent_type == "RandomNavigationAgent"
                    ), "The model is None but agent is not random agent"
            optimizer = None

    processes = []

    end_flag = mp.Value(ctypes.c_bool, False)

    train_res_queue = mp.Queue()

    start_time = time.time()
    local_start_time_str = time.strftime("%Y-%m-%d_%H:%M:%S",
                                         time.localtime(start_time))
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    if args.log_dir is not None:
        tb_log_dir = args.log_dir + "/" + args.title + "-" + local_start_time_str
        log_writer = SummaryWriter(log_dir=tb_log_dir)
    else:
        log_writer = SummaryWriter(comment=args.title)

    if args.gpu_ids == -1:
        args.gpu_ids = [-1]
    else:
        torch.cuda.manual_seed(args.seed)

    for rank in range(0, args.workers):
        p = mp.Process(
            target=target,
            args=(
                rank,
                args,
                create_shared_model,
                shared_model,
                init_agent,
                optimizer,
                train_res_queue,
                end_flag,
            ),
        )
        p.start()
        processes.append(p)
        time.sleep(0.1)

    print("Train agents created.")

    train_thin = args.train_thin
    train_scalars = ScalarMeanTracker()

    print(train_total_ep)
    print(optimizer)
    try:
        while train_total_ep < args.max_ep:

            train_result = train_res_queue.get()
            train_scalars.add_scalars(train_result)
            train_total_ep += 1
            n_frames += train_result["ep_length"]
            if (train_total_ep % train_thin) == 0:
                log_writer.add_scalar("n_frames", n_frames, train_total_ep)
                tracked_means = train_scalars.pop_and_reset()
                for k in tracked_means:
                    log_writer.add_scalar(k + "/train", tracked_means[k],
                                          train_total_ep)

            if (train_total_ep % args.ep_save_freq) == 0:

                print(n_frames)
                if not os.path.exists(args.save_model_dir):
                    os.makedirs(args.save_model_dir)
                state_to_save = shared_model.state_dict()
                save_path = os.path.join(
                    args.save_model_dir,
                    "{0}_{1}_{2}_{3}.dat".format(args.title, n_frames,
                                                 train_total_ep,
                                                 local_start_time_str),
                )
                save_dict = {
                    'model': state_to_save,
                    'train_total_ep': train_total_ep,
                    'optimizer': optimizer.state_dict(),
                    'n_frames': n_frames
                }
                torch.save(save_dict, save_path)
                #torch.save(state_to_save, save_path)

    finally:
        log_writer.close()
        end_flag.value = True
        for p in processes:
            time.sleep(0.1)
            p.join()
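
Unlike Example No. 3, this variant saves a dict rather than a bare state_dict, so training can resume with the optimizer state and the two counters intact (and any eval code that expects a bare state_dict must index into saved_state['model']). A self-contained toy sketch of the same round trip; the Linear model, file name, and counter values are placeholders:

import torch
import torch.nn as nn

model = nn.Linear(8, 4)
optimizer = torch.optim.Adam(model.parameters())

# save: everything needed to resume goes into one dict
torch.save({
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "train_total_ep": 7000,
    "n_frames": 123456,
}, "demo_checkpoint.dat")

# resume: restore weights, optimizer state, and counters
saved_state = torch.load("demo_checkpoint.dat",
                         map_location=lambda storage, loc: storage)
model.load_state_dict(saved_state["model"])
optimizer.load_state_dict(saved_state["optimizer"])
train_total_ep = saved_state["train_total_ep"]
n_frames = saved_state["n_frames"]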
Example No. 5
def main():
    # set the process title
    setproctitle.setproctitle("Train/Test Manager")

    # parse command-line arguments
    args = flag_parser.parse_arguments()

    if args.model == "SAVN":
        args.learned_loss = True
        args.num_steps = 6
        target = savn_val if args.eval else savn_train
    else:
        args.learned_loss = False
        args.num_steps = args.max_episode_length
        target = nonadaptivea3c_val if args.eval else nonadaptivea3c_train

    # check that pinned_scene and data_source do not conflict
    if args.data_source == "ithor" and args.pinned_scene == True:
        raise Exception(
            "Cannot set pinned_scene to true when using ithor dataset")

    # get the model class (not yet instantiated), e.g. <class 'models.basemodel.BaseModel'>
    create_shared_model = model_class(args.model)
    # get the agent class (not yet instantiated), default <class 'agents.navigation_agent.NavigationAgent'>
    init_agent = agent_class(args.agent_type)
    # get the optimizer class (not yet instantiated), default <class 'optimizers.shared_adam.SharedAdam'>
    optimizer_type = optimizer_class(args.optimizer)
    ########################  evaluation phase  ################################
    if args.eval:
        main_eval(args, create_shared_model, init_agent)
        return


    #######################  training phase  ###################################
    start_time = time.time()
    local_start_time_str = time.strftime("%Y-%m-%d_%H:%M:%S",
                                         time.localtime(start_time))
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # configure logging
    if args.log_dir is not None:
        tb_log_dir = args.log_dir + "/" + args.title + "-" + local_start_time_str
        log_writer = SummaryWriter(log_dir=tb_log_dir)
    else:
        log_writer = SummaryWriter(comment=args.title)

    if args.gpu_ids == -1:
        args.gpu_ids = [-1]
    else:
        torch.cuda.manual_seed(args.seed)
        mp.set_start_method("spawn")

    # instantiate the shared model (a torch.nn.Module subclass)
    shared_model = create_shared_model(args)

    optimizer = optimizer_type(
        filter(lambda p: p.requires_grad, shared_model.parameters()), args)
    # load a previously saved checkpoint, if any
    train_total_ep, n_frames = load_checkpoint(args, shared_model, optimizer)
    # TODO: delete this after debug
    # train_total_ep = 1000001

    if shared_model is not None:
        # share the model parameters across processes; this must be called
        # before torch.multiprocessing forks/spawns the workers
        shared_model.share_memory()
        # the optimizer is a torch.optim.Optimizer subclass; filter() passed
        # every parameter that requires gradient updates into its constructor

        optimizer.share_memory()
        print(shared_model)
    else:
        assert (args.agent_type == "RandomNavigationAgent"
                ), "The model is None but agent is not random agent"
        optimizer = None

    processes = []

    end_flag = mp.Value(ctypes.c_bool, False)
    global_ep = mp.Value(ctypes.c_int)

    global_ep.value = train_total_ep

    # queue shared across the worker processes
    train_res_queue = mp.Queue()
    # create the worker processes
    # target is the function each worker process executes
    for rank in range(0, args.workers):
        p = mp.Process(
            target=target,
            args=(rank, args, create_shared_model, shared_model, init_agent,
                  optimizer, train_res_queue, end_flag, global_ep),
        )
        p.start()
        processes.append(p)
        time.sleep(0.1)

    print("Train agents created.")

    train_thin = args.train_thin
    train_scalars = ScalarMeanTracker()

    # main process loop
    try:
        while train_total_ep < args.max_ep:

            train_result = train_res_queue.get()
            train_scalars.add_scalars(train_result)
            train_total_ep += 1
            global_ep.value = train_total_ep

            n_frames += train_result["ep_length"]
            if (train_total_ep % train_thin) == 0:
                log_writer.add_scalar("n_frames", n_frames, train_total_ep)
                tracked_means = train_scalars.pop_and_reset()
                for k in tracked_means:
                    log_writer.add_scalar(k + "/train", tracked_means[k],
                                          train_total_ep)

            if (train_total_ep % args.ep_save_freq) == 0:

                print(n_frames)
                if not os.path.exists(args.save_model_dir):
                    os.makedirs(args.save_model_dir)
                state_to_save = shared_model.state_dict()
                save_path = os.path.join(
                    args.save_model_dir,
                    "{0}_{1}_{2}_{3}.dat".format(args.title, n_frames,
                                                 train_total_ep,
                                                 local_start_time_str),
                )
                torch.save(state_to_save, save_path)

            if (train_total_ep % args.ep_save_ckpt) == 0:
                print("save check point at episode {}".format(train_total_ep))
                checkpoint = {
                    'train_total_ep': train_total_ep,
                    'n_frames': n_frames,
                    'shared_model': shared_model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                checkpoint_path = os.path.join(args.save_model_dir,
                                               "checkpoint.dat")
                torch.save(checkpoint, checkpoint_path)

    finally:
        log_writer.close()
        end_flag.value = True
        for p in processes:
            time.sleep(0.1)
            p.join()
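
load_checkpoint is called above but not shown. Below is a hedged sketch of what such a helper could look like, written to be consistent with the rolling checkpoint this example saves ('checkpoint.dat' in args.save_model_dir with keys 'shared_model', 'optimizer', 'train_total_ep', 'n_frames'); the real implementation may differ.

import os
import torch


def load_checkpoint(args, shared_model, optimizer):
    # Resume from the rolling checkpoint written at the bottom of the training
    # loop; start from scratch if it does not exist yet.
    checkpoint_path = os.path.join(args.save_model_dir, "checkpoint.dat")
    if not os.path.exists(checkpoint_path):
        return 0, 0
    checkpoint = torch.load(checkpoint_path,
                            map_location=lambda storage, loc: storage)
    shared_model.load_state_dict(checkpoint["shared_model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return checkpoint["train_total_ep"], checkpoint["n_frames"]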
Example No. 6
def main():
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    args = flag_parser.parse_arguments()

    create_shared_model = model_class(args.model)
    init_agent = agent_class(args.agent_type)

    args.episode_type = "TestValEpisode"
    args.test_or_val = "val"

    tb_log_dir = args.log_dir + "/" + '{}_{}_{}'.format(
        args.title, args.test_or_val,
        time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime(time.time())))
    log_writer = SummaryWriter(log_dir=tb_log_dir)

    print('Start Loading!')
    optimal_action_path = './data/AI2thor_Combine_Dataset/Optimal_Path_Combine.json'
    with open(optimal_action_path, 'r') as read_file:
        optimal_action_dict = json.load(read_file)
    manager = Manager()
    optimal_action = manager.dict()
    optimal_action.update(optimal_action_dict)
    glove_file_path = './data/AI2thor_Combine_Dataset/det_feature_512_eval.hdf5'
    glove_file = hdf5_to_dict(glove_file_path)
    print('Loading Success!')

    # Get all valid saved_models for the given title and sort by train_ep.
    checkpoints = [(f, f.split("_")) for f in os.listdir(args.save_model_dir)]
    checkpoints = [(f, int(s[-3])) for (f, s) in checkpoints
                   if len(s) >= 4 and f.startswith(args.title)]
    checkpoints.sort(key=lambda x: x[1])

    best_model_on_val = None
    best_performance_on_val = 0.0
    for (f, train_ep) in tqdm(checkpoints, desc="Checkpoints."):

        model = os.path.join(args.save_model_dir, f)
        args.load_model = model

        # run eval on model
        args.test_or_val = "val"
        main_eval(args, create_shared_model, init_agent, glove_file,
                  optimal_action)

        # check if best on val.
        with open(args.results_json, "r") as f:
            results = json.load(f)

        if results["success"] > best_performance_on_val:
            best_model_on_val = model
            best_performance_on_val = results["success"]

        log_writer.add_scalar("val/success", results["success"], train_ep)
        log_writer.add_scalar("val/spl", results["spl"], train_ep)

        # run on test.
        args.test_or_val = "test"
        main_eval(args, create_shared_model, init_agent, glove_file,
                  optimal_action)
        with open(args.results_json, "r") as f:
            results = json.load(f)

        log_writer.add_scalar("test/success", results["success"], train_ep)
        log_writer.add_scalar("test/spl", results["spl"], train_ep)

    args.record_route = True
    args.test_or_val = "test"
    args.load_model = best_model_on_val
    main_eval(args, create_shared_model, init_agent, glove_file,
              optimal_action)

    with open(args.results_json, "r") as f:
        results = json.load(f)

    print(
        tabulate(
            [
                ["SPL >= 1:", results["GreaterThan/1/spl"]],
                ["Success >= 1:", results["GreaterThan/1/success"]],
                ["SPL >= 5:", results["GreaterThan/5/spl"]],
                ["Success >= 5:", results["GreaterThan/5/success"]],
            ],
            headers=["Metric", "Result"],
            tablefmt="orgtbl",
        ))

    print("Best model:", args.load_model)
Example No. 7
def main():
    setproctitle.setproctitle("Train/Test Manager")
    args = flag_parser.parse_arguments()

    if args.model == "BaseModel" or args.model == "GCN":
        args.learned_loss = False
        args.num_steps = 50
        target = nonadaptivea3c_val if args.eval else nonadaptivea3c_train
    else:
        args.learned_loss = True
        args.num_steps = 6
        target = savn_val if args.eval else savn_train

    create_shared_model = model_class(args.model)
    init_agent = agent_class(args.agent_type)
    optimizer_type = optimizer_class(args.optimizer)

    if args.eval:
        main_eval(args, create_shared_model, init_agent)
        return

    start_time = time.time()
    local_start_time_str = time.strftime("%Y-%m-%d_%H:%M:%S",
                                         time.localtime(start_time))
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    if args.log_dir is not None:
        tb_log_dir = args.log_dir + "/" + args.title + "-" + local_start_time_str
        log_writer = SummaryWriter(log_dir=tb_log_dir)
    else:
        log_writer = SummaryWriter(comment=args.title)

    if args.gpu_ids == -1:
        args.gpu_ids = [-1]
    else:
        torch.cuda.manual_seed(args.seed)
        mp.set_start_method("spawn")

    shared_model = create_shared_model(args)

    train_total_ep = 0
    n_frames = 0

    if shared_model is not None:
        shared_model.share_memory()
        optimizer = optimizer_type(
            filter(lambda p: p.requires_grad, shared_model.parameters()), args)
        optimizer.share_memory()
        print(shared_model)
    else:
        assert (args.agent_type == "RandomNavigationAgent"
                ), "The model is None but agent is not random agent"
        optimizer = None

    processes = []

    print('Start Loading!')
    optimal_action_path = './data/AI2thor_Combine_Dataset/Optimal_Path_Combine.json'
    with open(optimal_action_path, 'r') as read_file:
        optimal_action_dict = json.load(read_file)
    manager = Manager()
    optimal_action = manager.dict()
    optimal_action.update(optimal_action_dict)
    glove_file_path = './data/AI2thor_Combine_Dataset/det_feature_512_train.hdf5'
    glove_file = hdf5_to_dict(glove_file_path)
    # det_gt_path = './data/AI2thor_Combine_Dataset/Instance_Detection_Combine.pkl'
    # with open(det_gt_path, 'rb') as read_file:
    #     det_gt = pickle.load(read_file)
    print('Loading Success!')

    end_flag = mp.Value(ctypes.c_bool, False)

    train_res_queue = mp.Queue()

    for rank in range(0, args.workers):
        p = mp.Process(
            target=target,
            args=(
                rank,
                args,
                create_shared_model,
                shared_model,
                init_agent,
                optimizer,
                train_res_queue,
                end_flag,
                glove_file,
                optimal_action,
                # det_gt,
            ),
        )
        p.start()
        processes.append(p)
        time.sleep(0.1)

    print("Train agents created.")

    train_thin = args.train_thin
    train_scalars = ScalarMeanTracker()

    # start_ep_time = time.time()

    try:
        while train_total_ep < args.max_ep:

            train_result = train_res_queue.get()
            train_scalars.add_scalars(train_result)
            train_total_ep += 1
            n_frames += train_result["ep_length"]
            # if train_total_ep % 10 == 0:
            #     print(n_frames / train_total_ep)
            #     print((time.time() - start_ep_time) / train_total_ep)
            if (train_total_ep % train_thin) == 0:
                log_writer.add_scalar("n_frames", n_frames, train_total_ep)
                tracked_means = train_scalars.pop_and_reset()
                for k in tracked_means:
                    log_writer.add_scalar(k + "/train", tracked_means[k],
                                          train_total_ep)

            if (train_total_ep % args.ep_save_freq) == 0:

                print(n_frames)
                if not os.path.exists(args.save_model_dir):
                    os.makedirs(args.save_model_dir)
                state_to_save = shared_model.state_dict()
                save_path = os.path.join(
                    args.save_model_dir,
                    "{0}_{1}_{2}_{3}.dat".format(args.title, n_frames,
                                                 train_total_ep,
                                                 local_start_time_str),
                )
                torch.save(state_to_save, save_path)

    finally:
        log_writer.close()
        end_flag.value = True
        for p in processes:
            time.sleep(0.1)
            p.join()
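
The Manager().dict() above keeps the optimal-action table in a single manager process, so every training worker reads the same copy through a proxy instead of each child receiving its own copy of the full JSON (at the cost of slower per-read access). A minimal sketch of that sharing pattern; the keys and values are made up:

from multiprocessing import Manager, Process


def worker(shared_actions):
    # each worker sees the same proxy-backed dict
    print(len(shared_actions), shared_actions.get("FloorPlan1|0|0"))


if __name__ == "__main__":
    manager = Manager()
    shared_actions = manager.dict()
    shared_actions.update({"FloorPlan1|0|0": [0, 2, 3]})
    p = Process(target=worker, args=(shared_actions,))
    p.start()
    p.join()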
Example No. 8
def full_eval(args=None):
    if args is None:
        args = command_parser.parse_arguments()

    create_shared_model = model_class(args.model)
    init_agent = agent_class(args.agent_type)

    args.phase = 'eval'
    args.episode_type = 'TestValEpisode'
    args.test_or_val = 'val'

    # if args.num_category != 60:
    #     args.detection_feature_file_name = 'det_feature_{}_categories.hdf5'.format(args.num_category)

    start_time = time.time()
    local_start_time_str = time.strftime('%Y_%m_%d_%H_%M_%S',
                                         time.localtime(start_time))

    tb_log_dir = args.log_dir + "/" + args.title + '_' + args.phase + '_' + local_start_time_str
    log_writer = SummaryWriter(log_dir=tb_log_dir)

    # Get all valid saved_models for the given title and sort by train_ep.
    checkpoints = [(f, f.split("_")) for f in os.listdir(args.save_model_dir)]
    # s[-7] is train_total_ep here: the '%Y_%m_%d_%H_%M_%S' timestamp adds six
    # underscore-separated tokens to the end of the checkpoint filename.
    checkpoints = [(f, int(s[-7])) for (f, s) in checkpoints
                   if len(s) >= 4 and f.startswith(args.title)
                   and int(s[-7]) >= args.test_start_from]
    checkpoints.sort(key=lambda x: x[1])

    best_model_on_val = None
    best_performance_on_val = 0.0
    for (f, train_ep) in tqdm(checkpoints, desc="Checkpoints."):

        model = os.path.join(args.save_model_dir, f)
        args.load_model = model

        # run eval on model
        # args.test_or_val = "val"
        args.test_or_val = "test"
        main_eval(args, create_shared_model, init_agent)

        # check if best on val.
        with open(args.results_json, "r") as f:
            results = json.load(f)

        if results["success"] > best_performance_on_val:
            best_model_on_val = model
            best_performance_on_val = results["success"]

        log_writer.add_scalar("val/success", results["success"], train_ep)
        log_writer.add_scalar("val/spl", results["spl"], train_ep)

        if args.include_test:
            args.test_or_val = "test"
            main_eval(args, create_shared_model, init_agent)
            with open(args.results_json, "r") as f:
                results = json.load(f)

            log_writer.add_scalar("test/success", results["success"], train_ep)
            log_writer.add_scalar("test/spl", results["spl"], train_ep)

    args.test_or_val = "test"
    args.load_model = best_model_on_val
    main_eval(args, create_shared_model, init_agent)

    with open(args.results_json, "r") as f:
        results = json.load(f)

    print(
        tabulate(
            [
                ["SPL >= 1:", results["GreaterThan/1/spl"]],
                ["Success >= 1:", results["GreaterThan/1/success"]],
                ["SPL >= 5:", results["GreaterThan/5/spl"]],
                ["Success >= 5:", results["GreaterThan/5/success"]],
            ],
            headers=["Metric", "Result"],
            tablefmt="orgtbl",
        ))

    print("Best model:", args.load_model)
Example No. 9
def main():
    setproctitle.setproctitle("Train/Test Manager")
    args = command_parser.parse_arguments()

    print('Training started from: {}'.format(
        time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
    )

    args.learned_loss = False
    args.num_steps = 50
    target = a3c_val if args.eval else a3c_train

    if args.csiro:
        args.data_dir = './data/'
    else:
        check_data(args)
    scenes = loading_scene_list(args)

    create_shared_model = model_class(args.model)
    init_agent = agent_class(args.agent_type)
    optimizer_type = optimizer_class(args.optimizer)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    if args.eval:
        main_eval(args, create_shared_model, init_agent)
        return

    start_time = time.time()
    local_start_time_str = time.strftime(
        '%Y_%m_%d_%H_%M_%S', time.localtime(start_time)
    )

    tb_log_dir = args.log_dir + '/' + args.title + '_' + args.phase + '_' + local_start_time_str
    log_writer = SummaryWriter(log_dir=tb_log_dir)

    if args.gpu_ids == -1:
        args.gpu_ids = [-1]
    else:
        torch.cuda.manual_seed(args.seed)
        mp.set_start_method("spawn")

    shared_model = create_shared_model(args)

    train_total_ep = 0
    n_frames = 0

    if args.continue_training is not None:
        saved_state = torch.load(
            args.continue_training, map_location=lambda storage, loc: storage
        )
        shared_model.load_state_dict(saved_state)

        train_total_ep = int(args.continue_training.split('_')[-7])
        n_frames = int(args.continue_training.split('_')[-8])

    if args.fine_tuning is not None:
        saved_state = torch.load(
            args.fine_tuning, map_location=lambda storage, loc: storage
        )
        model_dict = shared_model.state_dict()
        pretrained_dict = {k: v for k, v in saved_state.items() if (k in model_dict and v.shape == model_dict[k].shape)}
        model_dict.update(pretrained_dict)
        shared_model.load_state_dict(model_dict)

    if args.update_meta_network:
        for layer, parameters in shared_model.named_parameters():
            if not layer.startswith('meta'):
                parameters.requires_grad = False

    shared_model.share_memory()
    if args.fine_tune_graph:
        optimizer = optimizer_type(
            [
                {'params': [v for k, v in shared_model.named_parameters() if v.requires_grad and not k.startswith('graph')],
                 'lr': 0.00001},
                {'params': [v for k, v in shared_model.named_parameters() if v.requires_grad and k.startswith('graph')],
                 'lr': args.lr},
            ]
        )
    else:
        optimizer = optimizer_type(
            [v for k, v in shared_model.named_parameters() if v.requires_grad], lr=args.lr
        )
    optimizer.share_memory()
    print(shared_model)

    processes = []

    end_flag = mp.Value(ctypes.c_bool, False)
    train_res_queue = mp.Queue()

    for rank in range(0, args.workers):
        p = mp.Process(
            target=target,
            args=(
                rank,
                args,
                create_shared_model,
                shared_model,
                init_agent,
                optimizer,
                train_res_queue,
                end_flag,
                scenes,
            ),
        )
        p.start()
        processes.append(p)
        time.sleep(0.1)

    print("Train agents created.")

    train_thin = args.train_thin
    train_scalars = ScalarMeanTracker()

    try:
        while train_total_ep < args.max_ep:

            train_result = train_res_queue.get()
            train_scalars.add_scalars(train_result)
            train_total_ep += 1
            n_frames += train_result['ep_length']

            if (train_total_ep % train_thin) == 0:
                log_writer.add_scalar('n_frames', n_frames, train_total_ep)
                tracked_means = train_scalars.pop_and_reset()
                for k in tracked_means:
                    log_writer.add_scalar(
                        k + '/train', tracked_means[k], train_total_ep
                    )

            if (train_total_ep % args.ep_save_freq) == 0:

                print('{}: {}'.format(train_total_ep, n_frames))
                if not os.path.exists(args.save_model_dir):
                    os.makedirs(args.save_model_dir)
                state_to_save = shared_model.state_dict()
                save_path = os.path.join(
                    args.save_model_dir,
                    '{0}_{1}_{2}_{3}.dat'.format(
                        args.title, n_frames, train_total_ep, local_start_time_str
                    ),
                )
                torch.save(state_to_save, save_path)

    finally:
        log_writer.close()
        end_flag.value = True
        for p in processes:
            time.sleep(0.1)
            p.join()

    if args.test_after_train:
        full_eval()
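
The fine-tuning branch above copies only the pretrained tensors whose names and shapes match the current model, leaving any resized layers at their fresh initialization. A toy demonstration of that shape-matched partial load; the two small Sequential models are placeholders:

import torch
import torch.nn as nn

target = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
pretrained = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 5))  # different head size

saved_state = pretrained.state_dict()
model_dict = target.state_dict()
pretrained_dict = {k: v for k, v in saved_state.items()
                   if k in model_dict and v.shape == model_dict[k].shape}
model_dict.update(pretrained_dict)      # overwrite only the compatible tensors
target.load_state_dict(model_dict)

print(sorted(pretrained_dict))          # only the first layer transfers: ['0.bias', '0.weight']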