Example No. 1
def main():
    setproctitle.setproctitle("Train/Test Manager")
    args = flag_parser.parse_arguments()

    if args.model in ("BaseModel", "GCN_MLP", "GCN", "GCN_GRU"):
        args.learned_loss = False
        args.num_steps = 50
        target = nonadaptivea3c_val if args.eval else nonadaptivea3c_train
    else:
        # Fail early instead of hitting a NameError on `target` below when an
        # unsupported model name is passed in.
        raise ValueError("Unsupported model: {}".format(args.model))

    create_shared_model = model_class(args.model)
    init_agent = agent_class(args.agent_type)
    optimizer_type = optimizer_class(args.optimizer)

    if args.eval:
        main_eval(args, create_shared_model, init_agent)
        return

    model_to_open = args.load_model

    if model_to_open != "":
        shared_model = create_shared_model(args)
        optimizer = optimizer_type(
            filter(lambda p: p.requires_grad, shared_model.parameters()), args)
        saved_state = torch.load(model_to_open,
                                 map_location=lambda storage, loc: storage)
        shared_model.load_state_dict(saved_state['model'])
        optimizer.load_state_dict(saved_state['optimizer'])
        optimizer.share_memory()
        train_total_ep = saved_state['train_total_ep']
        n_frames = saved_state['n_frames']

    else:
        shared_model = create_shared_model(args)

        train_total_ep = 0
        n_frames = 0

        if shared_model is not None:
            shared_model.share_memory()
            optimizer = optimizer_type(
                filter(lambda p: p.requires_grad, shared_model.parameters()),
                args)
            optimizer.share_memory()
            print(shared_model)
        else:
            assert (args.agent_type == "RandomNavigationAgent"
                    ), "The model is None but agent is not random agent"
            optimizer = None

    processes = []

    end_flag = mp.Value(ctypes.c_bool, False)

    train_res_queue = mp.Queue()

    start_time = time.time()
    local_start_time_str = time.strftime("%Y-%m-%d_%H:%M:%S",
                                         time.localtime(start_time))
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    if args.log_dir is not None:
        tb_log_dir = args.log_dir + "/" + args.title + "-" + local_start_time_str
        log_writer = SummaryWriter(log_dir=tb_log_dir)
    else:
        log_writer = SummaryWriter(comment=args.title)

    if args.gpu_ids == -1:
        args.gpu_ids = [-1]
    else:
        torch.cuda.manual_seed(args.seed)

    for rank in range(0, args.workers):
        p = mp.Process(
            target=target,
            args=(
                rank,
                args,
                create_shared_model,
                shared_model,
                init_agent,
                optimizer,
                train_res_queue,
                end_flag,
            ),
        )
        p.start()
        processes.append(p)
        time.sleep(0.1)

    print("Train agents created.")

    train_thin = args.train_thin
    train_scalars = ScalarMeanTracker()

    print(train_total_ep)
    print(optimizer)
    try:
        while train_total_ep < args.max_ep:

            train_result = train_res_queue.get()
            train_scalars.add_scalars(train_result)
            train_total_ep += 1
            n_frames += train_result["ep_length"]
            if (train_total_ep % train_thin) == 0:
                log_writer.add_scalar("n_frames", n_frames, train_total_ep)
                tracked_means = train_scalars.pop_and_reset()
                for k in tracked_means:
                    log_writer.add_scalar(k + "/train", tracked_means[k],
                                          train_total_ep)

            if (train_total_ep % args.ep_save_freq) == 0:

                print(n_frames)
                if not os.path.exists(args.save_model_dir):
                    os.makedirs(args.save_model_dir)
                state_to_save = shared_model.state_dict()
                save_path = os.path.join(
                    args.save_model_dir,
                    "{0}_{1}_{2}_{3}.dat".format(args.title, n_frames,
                                                 train_total_ep,
                                                 local_start_time_str),
                )
                save_dict = {
                    'model': state_to_save,
                    'train_total_ep': train_total_ep,
                    'optimizer': optimizer.state_dict(),
                    'n_frames': n_frames
                }
                torch.save(save_dict, save_path)
                #torch.save(state_to_save, save_path)

    finally:
        log_writer.close()
        end_flag.value = True
        for p in processes:
            time.sleep(0.1)
            p.join()
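Across all of these examples the manager loop only talks to the workers through train_res_queue and end_flag. A minimal sketch of that contract, with a hypothetical toy_worker and made-up scalar names, assuming each worker pushes one result dict per episode containing at least an "ep_length" key:

import ctypes
import torch.multiprocessing as mp


def toy_worker(rank, train_res_queue, end_flag):
    # Hypothetical worker: one result dict per episode; any extra scalar keys
    # are averaged by ScalarMeanTracker and logged under "<key>/train".
    if not end_flag.value:
        train_res_queue.put({
            "ep_length": 50,      # frames consumed by this episode
            "total_reward": 0.0,  # illustrative extra scalar
        })


if __name__ == "__main__":
    end_flag = mp.Value(ctypes.c_bool, False)
    queue = mp.Queue()
    p = mp.Process(target=toy_worker, args=(0, queue, end_flag))
    p.start()
    print(queue.get()["ep_length"])  # the manager consumes results exactly like this
    p.join()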
Example No. 2
def main():
    setproctitle.setproctitle("Train/Test Manager")
    args = flag_parser.parse_arguments()

    if args.model == "BaseModel" or args.model == "GCN":
        args.learned_loss = False
        args.num_steps = 50
        target = nonadaptivea3c_val if args.eval else nonadaptivea3c_train
    # else:
    #     args.learned_loss = True
    #     args.num_steps = 6
    #     target = savn_val if args.eval else savn_train

    create_shared_model = model_class(args.model)
    init_agent = agent_class(args.agent_type)
    optimizer_type = optimizer_class(args.optimizer)
    # print('shared model created')
    if args.eval:
        main_eval(args, create_shared_model, init_agent)
        return

    start_time = time.time()
    local_start_time_str = time.strftime("%Y-%m-%d_%H:%M:%S",
                                         time.localtime(start_time))
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # print('seeding done')

    if args.log_dir is not None:
        tb_log_dir = args.log_dir + "/" + args.title + "-" + local_start_time_str
        log_writer = SummaryWriter(log_dir=tb_log_dir)
    else:
        log_writer = SummaryWriter(comment=args.title)

    if args.gpu_ids == -1:
        args.gpu_ids = [-1]
    else:
        # print('something to do with cuda')
        torch.cuda.manual_seed(args.seed)
        mp.set_start_method("spawn")

    shared_model = create_shared_model(args)

    train_total_ep = 0
    n_frames = 0

    if shared_model is not None:
        # print('shared model is being created')
        shared_model.share_memory()
        optimizer = optimizer_type(
            filter(lambda p: p.requires_grad, shared_model.parameters()), args)
        optimizer.share_memory()
        print(shared_model)
        # print('!!!!!!!!!!!!')
    else:
        assert (args.agent_type == "RandomNavigationAgent"
                ), "The model is None but agent is not random agent"
        optimizer = None

    processes = []

    end_flag = mp.Value(ctypes.c_bool, False)

    train_res_queue = mp.Queue()

    for rank in range(0, args.workers):
        print('Process {} being created'.format(rank))
        p = mp.Process(
            target=target,
            args=(
                rank,
                args,
                create_shared_model,
                shared_model,
                init_agent,
                optimizer,
                train_res_queue,
                end_flag,
            ),
        )
        p.start()
        processes.append(p)
        time.sleep(0.1)

    print("Train agents created.")

    train_thin = args.train_thin
    train_scalars = ScalarMeanTracker()

    try:
        while train_total_ep < args.max_ep:
            print('total train ep: {} of {}'.format(train_total_ep,
                                                    args.max_ep))
            print('Cuda available: {}'.format(torch.cuda.is_available()))
            train_result = train_res_queue.get()
            print('Got the train result from the queue')
            train_scalars.add_scalars(train_result)
            train_total_ep += 1
            n_frames += train_result["ep_length"]
            if (train_total_ep % train_thin) == 0:
                log_writer.add_scalar("n_frames", n_frames, train_total_ep)
                tracked_means = train_scalars.pop_and_reset()
                for k in tracked_means:
                    log_writer.add_scalar(k + "/train", tracked_means[k],
                                          train_total_ep)

            if (train_total_ep % args.ep_save_freq) == 0:

                print(n_frames)
                if not os.path.exists(args.save_model_dir):
                    os.makedirs(args.save_model_dir)
                state_to_save = shared_model.state_dict()
                save_path = os.path.join(
                    args.save_model_dir,
                    "{0}_{1}_{2}_{3}.dat".format(args.title, n_frames,
                                                 train_total_ep,
                                                 local_start_time_str),
                )
                torch.save(state_to_save, save_path)

    finally:
        log_writer.close()
        end_flag.value = True
        for p in processes:
            time.sleep(0.1)
            p.join()
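Example 2 adds mp.set_start_method("spawn") on the GPU path. That call is needed because CUDA state does not survive a fork, and it may only be invoked once per program; a short sketch of the defensive variant also used in Examples 4 and 5:

import torch.multiprocessing as mp

if __name__ == "__main__":
    # CUDA does not support forked subprocesses, so GPU workers must be spawned.
    # force=True avoids a RuntimeError if the start method was already set.
    mp.set_start_method("spawn", force=True)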
Example No. 3
def main():
    # Set the process title
    setproctitle.setproctitle("Train/Test Manager")

    # Parse the command-line arguments
    args = flag_parser.parse_arguments()

    if args.model == "SAVN":
        args.learned_loss = True
        args.num_steps = 6
        target = savn_val if args.eval else savn_train
    else:
        args.learned_loss = False
        args.num_steps = args.max_episode_length
        target = nonadaptivea3c_val if args.eval else nonadaptivea3c_train

    # Check that pinned_scene and data_source do not conflict
    if args.data_source == "ithor" and args.pinned_scene == True:
        raise Exception(
            "Cannot set pinned_scene to true when using ithor dataset")

    # Get the model class (not instantiated yet), e.g. <class 'models.basemodel.BaseModel'>
    create_shared_model = model_class(args.model)
    # Get the agent class (not instantiated yet), default <class 'agents.navigation_agent.NavigationAgent'>
    init_agent = agent_class(args.agent_type)
    # Get the optimizer class (not instantiated yet), default <class 'optimizers.shared_adam.SharedAdam'>
    optimizer_type = optimizer_class(args.optimizer)
    ########################  Evaluation phase ################################
    if args.eval:
        main_eval(args, create_shared_model, init_agent)
        return


####################### Training phase #################################
    start_time = time.time()
    local_start_time_str = time.strftime("%Y-%m-%d_%H:%M:%S",
                                         time.localtime(start_time))
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # Configure the TensorBoard logger
    if args.log_dir is not None:
        tb_log_dir = args.log_dir + "/" + args.title + "-" + local_start_time_str
        log_writer = SummaryWriter(log_dir=tb_log_dir)
    else:
        log_writer = SummaryWriter(comment=args.title)

    if args.gpu_ids == -1:
        args.gpu_ids = [-1]
    else:
        torch.cuda.manual_seed(args.seed)
        mp.set_start_method("spawn")

    # Instantiate a torch.nn.Module subclass as the shared model
    shared_model = create_shared_model(args)

    if shared_model is not None:
        # Share the model parameters across processes; share_memory() must be
        # called before torch.multiprocessing forks the worker processes.
        shared_model.share_memory()
        # Create a torch.optim.Optimizer subclass object; filter() passes only
        # the model parameters that require gradient updates to its constructor.
        optimizer = optimizer_type(
            filter(lambda p: p.requires_grad, shared_model.parameters()), args)
        optimizer.share_memory()
        # Load a previously saved checkpoint, if any
        train_total_ep, n_frames = load_checkpoint(args, shared_model, optimizer)
        # TODO: delete this after debug
        # train_total_ep = 1000001
        print(shared_model)
    else:
        assert (args.agent_type == "RandomNavigationAgent"
                ), "The model is None but agent is not random agent"
        optimizer = None
        train_total_ep = 0
        n_frames = 0

    processes = []

    end_flag = mp.Value(ctypes.c_bool, False)
    global_ep = mp.Value(ctypes.c_int)

    global_ep.value = train_total_ep

    # Result queue shared across worker processes
    train_res_queue = mp.Queue()
    # Spawn the worker processes; `target` is the function each process executes
    for rank in range(0, args.workers):
        p = mp.Process(
            target=target,
            args=(rank, args, create_shared_model, shared_model, init_agent,
                  optimizer, train_res_queue, end_flag, global_ep),
        )
        p.start()
        processes.append(p)
        time.sleep(0.1)

    print("Train agents created.")

    train_thin = args.train_thin
    train_scalars = ScalarMeanTracker()

    # Main thread: gather worker results and handle logging/saving
    try:
        while train_total_ep < args.max_ep:

            train_result = train_res_queue.get()
            train_scalars.add_scalars(train_result)
            train_total_ep += 1
            global_ep.value = train_total_ep

            n_frames += train_result["ep_length"]
            if (train_total_ep % train_thin) == 0:
                log_writer.add_scalar("n_frames", n_frames, train_total_ep)
                tracked_means = train_scalars.pop_and_reset()
                for k in tracked_means:
                    log_writer.add_scalar(k + "/train", tracked_means[k],
                                          train_total_ep)

            if (train_total_ep % args.ep_save_freq) == 0:

                print(n_frames)
                if not os.path.exists(args.save_model_dir):
                    os.makedirs(args.save_model_dir)
                state_to_save = shared_model.state_dict()
                save_path = os.path.join(
                    args.save_model_dir,
                    "{0}_{1}_{2}_{3}.dat".format(args.title, n_frames,
                                                 train_total_ep,
                                                 local_start_time_str),
                )
                torch.save(state_to_save, save_path)

            if (train_total_ep % args.ep_save_ckpt) == 0:
                print("save check point at episode {}".format(train_total_ep))
                checkpoint = {
                    'train_total_ep': train_total_ep,
                    'n_frames': n_frames,
                    'shared_model': shared_model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                checkpoint_path = os.path.join(args.save_model_dir,
                                               "checkpoint.dat")
                torch.save(checkpoint, checkpoint_path)

    finally:
        log_writer.close()
        end_flag.value = True
        for p in processes:
            time.sleep(0.1)
            p.join()
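load_checkpoint is imported from elsewhere and is not shown here. Judging from the keys written to checkpoint.dat at the bottom of the loop, a compatible sketch (an assumption, not the repository's actual helper) could be:

import os
import torch


def load_checkpoint(args, shared_model, optimizer):
    # Counterpart to the checkpoint saved above; returns the training counters,
    # or zeros when no checkpoint exists yet.
    checkpoint_path = os.path.join(args.save_model_dir, "checkpoint.dat")
    if not os.path.exists(checkpoint_path):
        return 0, 0
    checkpoint = torch.load(checkpoint_path,
                            map_location=lambda storage, loc: storage)
    shared_model.load_state_dict(checkpoint['shared_model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['train_total_ep'], checkpoint['n_frames']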
Example No. 4
def main():
    print('Starting.')

    setproctitle.setproctitle('A3C Manager')
    args = flag_parser.parse_arguments()

    create_shared_model = model.Model
    init_agent = agent.A3CAgent
    optimizer_type = optimizer_class(args.optimizer)

    start_time = time.time()
    local_start_time_str = \
        time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime(start_time))

    # Seed sources of randomness.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    if args.enable_logging:
        from tensorboardX import SummaryWriter
        log_dir = 'runs/' + args.title + '-' + local_start_time_str
        log_writer = SummaryWriter(log_dir=log_dir)

    if args.gpu_ids == -1:
        args.gpu_ids = [-1]
    else:
        torch.cuda.manual_seed(args.seed)
        mp.set_start_method('spawn', force=True)

    print('=> Creating the shared model and optimizer.')
    shared_model = create_shared_model(args)

    shared_model.share_memory()

    if args.resume:
        shared_model.load_state_dict(torch.load('./models/last_model'))
    elif args.load_model != '':
        shared_model.load_state_dict(torch.load(args.load_model))
    else:
        print("NO MODEL SUPPLIED")
        return

    print('=> Creating the agents.')
    processes = []

    end_flag = mp.Value(ctypes.c_bool, False)

    ## TEST ##
    if (args.num_test_episodes == 0):
        return
    print("Testing...")
    # Turn on random initialization for testing
    args.randomize_objects = True
    end_flag.value = False
    test_res_queue = mp.Queue()
    for rank in range(0, args.workers):
        p = mp.Process(target=train.test,
                       args=(rank, args, create_shared_model, shared_model,
                             init_agent, test_res_queue, end_flag))
        p.start()
        processes.append(p)
        print('* Agent created.')
        time.sleep(0.1)

    test_total_ep = 0
    n_frames = 0

    test_thin = args.test_thin
    test_scalars = ScalarMeanTracker()

    try:
        while test_total_ep < args.num_test_episodes:
            test_result = test_res_queue.get()
            test_scalars.add_scalars(test_result)
            test_total_ep += 1
            n_frames += test_result["ep_length"]
            if args.enable_logging and test_total_ep % test_thin == 0:
                log_writer.add_scalar("n_frames", n_frames, test_total_ep)
                tracked_means = test_scalars.pop_and_reset()
                for k in tracked_means:
                    log_writer.add_scalar(k + "/test", tracked_means[k],
                                          test_total_ep)

    finally:
        if args.enable_logging:
            log_writer.close()
        end_flag.value = True
        for p in processes:
            time.sleep(0.1)
            p.join()
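ScalarMeanTracker is used in every example but defined elsewhere. Inferring from its usage (add_scalars once per episode, pop_and_reset at each logging step), a minimal stand-in might look like the following sketch; the real class may differ:

class ScalarMeanTracker:
    # Minimal stand-in inferred from usage, not the repository's implementation.
    def __init__(self):
        self._sums = {}
        self._counts = {}

    def add_scalars(self, scalars):
        # Accumulate every scalar reported by a finished episode.
        for k, v in scalars.items():
            self._sums[k] = self._sums.get(k, 0.0) + v
            self._counts[k] = self._counts.get(k, 0) + 1

    def pop_and_reset(self):
        # Return the per-key means since the last call and clear the state.
        means = {k: self._sums[k] / self._counts[k] for k in self._sums}
        self._sums, self._counts = {}, {}
        return means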
Example No. 5
def main():
    print('Starting.')

    setproctitle.setproctitle('A3C Manager')
    args = flag_parser.parse_arguments()

    create_shared_model = model.Model
    init_agent = agent.A3CAgent
    optimizer_type = optimizer_class(args.optimizer)

    start_time = time.time()
    local_start_time_str = \
        time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime(start_time))

    # Seed sources of randomness.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    if args.enable_logging:
        from tensorboardX import SummaryWriter
        log_dir = 'runs/' + args.prepend_log + args.title + '-' + local_start_time_str
        log_writer = SummaryWriter(log_dir=log_dir)

    if args.gpu_ids == -1:
        args.gpu_ids = [-1]
    else:
        torch.cuda.manual_seed(args.seed)
        mp.set_start_method('spawn', force=True)

    print('=> Creating the shared model and optimizer.')
    shared_model = create_shared_model(args)

    shared_model.share_memory()
    optimizer = optimizer_type(
        filter(lambda p: p.requires_grad, shared_model.parameters()), args)
    optimizer.share_memory()

    if args.resume:
        shared_model.load_state_dict(
            torch.load('./models/{}_last_model'.format(args.prepend_log)))
    elif args.load_model != '':
        shared_model.load_state_dict(torch.load(args.load_model))

    print('=> Creating the agents.')
    processes = []

    end_flag = mp.Value(ctypes.c_bool, False)

    train_res_queue = mp.Queue()
    for rank in range(0, args.workers):
        p = mp.Process(target=train.train,
                       args=(rank, args, create_shared_model, shared_model,
                             init_agent, optimizer, train_res_queue, end_flag))
        p.start()
        processes.append(p)
        print('* Agent created.')
        time.sleep(0.1)

    train_total_ep = 0
    n_frames = 0

    train_thin = args.train_thin
    train_scalars = ScalarMeanTracker()

    success_tracker = []

    try:
        while train_total_ep < args.num_train_episodes:
            train_result = train_res_queue.get()
            train_scalars.add_scalars(train_result)
            train_total_ep += 1
            n_frames += train_result["ep_length"]
            if train_total_ep % 100 == 0:
                torch.save(
                    shared_model.state_dict(),
                    './models/{}_model_{}'.format(args.prepend_log,
                                                  train_total_ep))
            if args.enable_logging and train_total_ep % train_thin == 0:
                log_writer.add_scalar("n_frames", n_frames, train_total_ep)
                tracked_means = train_scalars.pop_and_reset()
                for k in tracked_means:
                    log_writer.add_scalar(k + "/train", tracked_means[k],
                                          train_total_ep)
            success_tracker.append(train_result["success"])
            if len(success_tracker) > 100:
                success_tracker.pop(0)
            if len(success_tracker) >= 100 and sum(success_tracker) / len(
                    success_tracker) > args.train_threshold:
                break
    finally:
        if args.enable_logging:
            log_writer.close()
        end_flag.value = True
        for p in processes:
            time.sleep(0.1)
            p.join()

    torch.save(shared_model.state_dict(),
               './models/{}_last_model'.format(args.prepend_log))
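Example 5 stops early once the success rate over the last 100 episodes exceeds args.train_threshold, tracked with a plain list plus pop(0). A deque with a maxlen gives the same rolling window without the O(n) pop; a small sketch of the equivalent check:

from collections import deque

success_window = deque(maxlen=100)  # keeps only the most recent 100 outcomes


def should_stop(success, threshold):
    # Record the latest episode outcome and stop once the window is full and
    # its success rate exceeds the threshold, mirroring the loop above.
    success_window.append(success)
    return (len(success_window) == success_window.maxlen
            and sum(success_window) / len(success_window) > threshold)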
Example No. 6
def main():
    setproctitle.setproctitle("Train/Test Manager")
    args = flag_parser.parse_arguments()

    if args.model == "BaseModel" or args.model == "GCN":
        args.learned_loss = False
        args.num_steps = 50
        target = nonadaptivea3c_val if args.eval else nonadaptivea3c_train
    else:
        args.learned_loss = True
        args.num_steps = 6
        target = savn_val if args.eval else savn_train

    create_shared_model = model_class(args.model)
    init_agent = agent_class(args.agent_type)
    optimizer_type = optimizer_class(args.optimizer)

    if args.eval:
        main_eval(args, create_shared_model, init_agent)
        return

    start_time = time.time()
    local_start_time_str = time.strftime("%Y-%m-%d_%H:%M:%S",
                                         time.localtime(start_time))
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    if args.log_dir is not None:
        tb_log_dir = args.log_dir + "/" + args.title + "-" + local_start_time_str
        log_writer = SummaryWriter(log_dir=tb_log_dir)
    else:
        log_writer = SummaryWriter(comment=args.title)

    if args.gpu_ids == -1:
        args.gpu_ids = [-1]
    else:
        torch.cuda.manual_seed(args.seed)
        mp.set_start_method("spawn")

    shared_model = create_shared_model(args)

    train_total_ep = 0
    n_frames = 0

    if shared_model is not None:
        shared_model.share_memory()
        optimizer = optimizer_type(
            filter(lambda p: p.requires_grad, shared_model.parameters()), args)
        optimizer.share_memory()
        print(shared_model)
    else:
        assert (args.agent_type == "RandomNavigationAgent"
                ), "The model is None but agent is not random agent"
        optimizer = None

    processes = []

    print('Start Loading!')
    optimal_action_path = './data/AI2thor_Combine_Dataset/Optimal_Path_Combine.json'
    with open(optimal_action_path, 'r') as read_file:
        optimal_action_dict = json.load(read_file)
    manager = Manager()
    optimal_action = manager.dict()
    optimal_action.update(optimal_action_dict)
    glove_file_path = './data/AI2thor_Combine_Dataset/det_feature_512_train.hdf5'
    glove_file = hdf5_to_dict(glove_file_path)
    # det_gt_path = './data/AI2thor_Combine_Dataset/Instance_Detection_Combine.pkl'
    # with open(det_gt_path, 'rb') as read_file:
    #     det_gt = pickle.load(read_file)
    print('Loading Success!')

    end_flag = mp.Value(ctypes.c_bool, False)

    train_res_queue = mp.Queue()

    for rank in range(0, args.workers):
        p = mp.Process(
            target=target,
            args=(
                rank,
                args,
                create_shared_model,
                shared_model,
                init_agent,
                optimizer,
                train_res_queue,
                end_flag,
                glove_file,
                optimal_action,
                # det_gt,
            ),
        )
        p.start()
        processes.append(p)
        time.sleep(0.1)

    print("Train agents created.")

    train_thin = args.train_thin
    train_scalars = ScalarMeanTracker()

    # start_ep_time = time.time()

    try:
        while train_total_ep < args.max_ep:

            train_result = train_res_queue.get()
            train_scalars.add_scalars(train_result)
            train_total_ep += 1
            n_frames += train_result["ep_length"]
            # if train_total_ep % 10 == 0:
            #     print(n_frames / train_total_ep)
            #     print((time.time() - start_ep_time) / train_total_ep)
            if (train_total_ep % train_thin) == 0:
                log_writer.add_scalar("n_frames", n_frames, train_total_ep)
                tracked_means = train_scalars.pop_and_reset()
                for k in tracked_means:
                    log_writer.add_scalar(k + "/train", tracked_means[k],
                                          train_total_ep)

            if (train_total_ep % args.ep_save_freq) == 0:

                print(n_frames)
                if not os.path.exists(args.save_model_dir):
                    os.makedirs(args.save_model_dir)
                state_to_save = shared_model.state_dict()
                save_path = os.path.join(
                    args.save_model_dir,
                    "{0}_{1}_{2}_{3}.dat".format(args.title, n_frames,
                                                 train_total_ep,
                                                 local_start_time_str),
                )
                torch.save(state_to_save, save_path)

    finally:
        log_writer.close()
        end_flag.value = True
        for p in processes:
            time.sleep(0.1)
            p.join()
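Example 6 loads the optimal-action JSON once and shares it with every worker through a multiprocessing Manager proxy instead of re-reading the file in each process. The same pattern in isolation, with stand-in data in place of the real dataset:

import multiprocessing as mp


def worker(shared_lookup):
    # Reads go through the manager proxy, so each process sees the same data
    # without holding its own copy.
    print(len(shared_lookup))


if __name__ == "__main__":
    manager = mp.Manager()
    shared_lookup = manager.dict()
    shared_lookup.update({"FloorPlan1": ["MoveAhead", "RotateRight"]})  # stand-in data
    p = mp.Process(target=worker, args=(shared_lookup,))
    p.start()
    p.join()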
Example No. 7
def main():
    setproctitle.setproctitle("Train/Test Manager")
    args = command_parser.parse_arguments()

    print('Training started from: {}'.format(
        time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
    )

    args.learned_loss = False
    args.num_steps = 50
    target = a3c_val if args.eval else a3c_train

    if args.csiro:
        args.data_dir = './data/'
    else:
        check_data(args)
    scenes = loading_scene_list(args)

    create_shared_model = model_class(args.model)
    init_agent = agent_class(args.agent_type)
    optimizer_type = optimizer_class(args.optimizer)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    if args.eval:
        main_eval(args, create_shared_model, init_agent)
        return

    start_time = time.time()
    local_start_time_str = time.strftime(
        '%Y_%m_%d_%H_%M_%S', time.localtime(start_time)
    )

    tb_log_dir = args.log_dir + '/' + args.title + '_' + args.phase + '_' + local_start_time_str
    log_writer = SummaryWriter(log_dir=tb_log_dir)

    if args.gpu_ids == -1:
        args.gpu_ids = [-1]
    else:
        torch.cuda.manual_seed(args.seed)
        mp.set_start_method("spawn")

    shared_model = create_shared_model(args)

    train_total_ep = 0
    n_frames = 0

    if args.continue_training is not None:
        saved_state = torch.load(
            args.continue_training, map_location=lambda storage, loc: storage
        )
        shared_model.load_state_dict(saved_state)

        train_total_ep = int(args.continue_training.split('_')[-7])
        n_frames = int(args.continue_training.split('_')[-8])

    if args.fine_tuning is not None:
        saved_state = torch.load(
            args.fine_tuning, map_location=lambda storage, loc: storage
        )
        model_dict = shared_model.state_dict()
        pretrained_dict = {k: v for k, v in saved_state.items() if (k in model_dict and v.shape == model_dict[k].shape)}
        model_dict.update(pretrained_dict)
        shared_model.load_state_dict(model_dict)

    if args.update_meta_network:
        for layer, parameters in shared_model.named_parameters():
            if not layer.startswith('meta'):
                parameters.requires_grad = False

    shared_model.share_memory()
    if args.fine_tune_graph:
        optimizer = optimizer_type(
            [
                {'params': [v for k, v in shared_model.named_parameters() if v.requires_grad and not k.startswith('graph')],
                 'lr': 0.00001},
                {'params': [v for k, v in shared_model.named_parameters() if v.requires_grad and k.startswith('graph')],
                 'lr': args.lr},
            ]
        )
    else:
        optimizer = optimizer_type(
            [v for k, v in shared_model.named_parameters() if v.requires_grad], lr=args.lr
        )
    optimizer.share_memory()
    print(shared_model)

    processes = []

    end_flag = mp.Value(ctypes.c_bool, False)
    train_res_queue = mp.Queue()

    for rank in range(0, args.workers):
        p = mp.Process(
            target=target,
            args=(
                rank,
                args,
                create_shared_model,
                shared_model,
                init_agent,
                optimizer,
                train_res_queue,
                end_flag,
                scenes,
            ),
        )
        p.start()
        processes.append(p)
        time.sleep(0.1)

    print("Train agents created.")

    train_thin = args.train_thin
    train_scalars = ScalarMeanTracker()

    try:
        while train_total_ep < args.max_ep:

            train_result = train_res_queue.get()
            train_scalars.add_scalars(train_result)
            train_total_ep += 1
            n_frames += train_result['ep_length']

            if (train_total_ep % train_thin) == 0:
                log_writer.add_scalar('n_frames', n_frames, train_total_ep)
                tracked_means = train_scalars.pop_and_reset()
                for k in tracked_means:
                    log_writer.add_scalar(
                        k + '/train', tracked_means[k], train_total_ep
                    )

            if (train_total_ep % args.ep_save_freq) == 0:

                print('{}: {}'.format(train_total_ep, n_frames))
                if not os.path.exists(args.save_model_dir):
                    os.makedirs(args.save_model_dir)
                state_to_save = shared_model.state_dict()
                save_path = os.path.join(
                    args.save_model_dir,
                    '{0}_{1}_{2}_{3}.dat'.format(
                        args.title, n_frames, train_total_ep, local_start_time_str
                    ),
                )
                torch.save(state_to_save, save_path)

    finally:
        log_writer.close()
        end_flag.value = True
        for p in processes:
            time.sleep(0.1)
            p.join()

    if args.test_after_train:
        full_eval()
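The fine-tuning branch in this example only copies checkpoint entries whose names and shapes match the current model, so a checkpoint from a slightly different architecture can still seed most layers. The same filtering pattern as a standalone helper (a sketch, with a hypothetical checkpoint path argument):

import torch
import torch.nn as nn


def load_matching_weights(model: nn.Module, checkpoint_path: str):
    # Keep only entries that exist in the current model with an identical
    # tensor shape, then merge them into the model's own state dict.
    saved_state = torch.load(checkpoint_path,
                             map_location=lambda storage, loc: storage)
    model_dict = model.state_dict()
    compatible = {k: v for k, v in saved_state.items()
                  if k in model_dict and v.shape == model_dict[k].shape}
    model_dict.update(compatible)
    model.load_state_dict(model_dict)
    return sorted(compatible)  # names of the layers that were actually restored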