Code Example #1
import argparse
import glob
import os
import sys
from collections import OrderedDict
from multiprocessing import Pool

import torch

# simpleTDNN, xvecTDNN, getSplitNum, and par_core_extractXvectors are assumed
# to come from the project's own modules (see train_utils.py).


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-modelType',
                        default=4,
                        type=int,
                        help='Refer train_utils.py ')
    parser.add_argument('-numSpkrs',
                        default=7323,
                        type=int,
                        help='Number of output labels for model')
    parser.add_argument('modelDirectory',
                        help='Directory containing the model checkpoints')
    parser.add_argument(
        'featDir', help='Directory containing features ready for extraction')
    parser.add_argument('embeddingDir', help='Output directory')
    args = parser.parse_args()

    # Use the most recently written file in modelDirectory as the checkpoint.
    modelFile = max(glob.glob(args.modelDirectory + '/*'),
                    key=os.path.getctime)
    # Load model definition
    if args.modelType == 3:
        net = simpleTDNN(args.numSpkrs, p_dropout=0)
    else:
        net = xvecTDNN(args.numSpkrs, p_dropout=0)

    checkpoint = torch.load(modelFile, map_location=torch.device('cuda'))
    new_state_dict = OrderedDict()
    for k, v in checkpoint['model_state_dict'].items():
        if k.startswith('module.'):
            new_state_dict[k[7:]] = v  # strip the 'module.' prefix added by nn.DataParallel
        else:
            new_state_dict[k] = v

    # load trained weights
    net.load_state_dict(new_state_dict)
    net = net.cuda()
    net.eval()

    # Parallel Processing
    try:
        nSplits = int(
            sorted(glob.glob(args.featDir + '/split*'),
                   key=getSplitNum)[-1].split('/')[-1].lstrip('split'))
    except (IndexError, ValueError):
        print('Cannot find %s/splitN directory' % args.featDir)
        sys.exit(1)

    if not os.path.isdir(args.embeddingDir):
        os.makedirs(args.embeddingDir)
    nProcs = nSplits
    L = [('%s/split%d/%d/feats.scp' % (args.featDir, nSplits, i),
          '%s/xvector.%d.ark' % (args.embeddingDir, i),
          '%s/xvector.%d.scp' % (args.embeddingDir, i), net, 'fc1')
         for i in range(1, nSplits + 1)]
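    # One task per split; starmap blocks until every split has been processed,
    # so terminate() below only reaps idle workers.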
    pool2 = Pool(processes=nProcs)
    result = pool2.starmap(par_core_extractXvectors, L)
    pool2.terminate()

    os.system('cat %s/xvector.*.scp > %s/xvector.scp' %
              (args.embeddingDir, args.embeddingDir))
Code Example #2
File: make_NN_vs_other.py Project: hsdtlx/shadow_sim
def run_main():
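    # `args`, `mylogger`, `deck_id_2_name`, and `key_2_tsv_name` are
    # module-level globals defined elsewhere in the script.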
    deck_lists = list(map(int, args.decklists.split(","))) if args.decklists is not None else None
    if deck_lists is None:
        deck_lists = list(deck_id_2_name.keys())
    else:
        assert all(key in deck_id_2_name for key in deck_lists)
    if deck_lists == [0, 1, 4, 5, 10, 12]:
        deck_lists = [0, 1, 4, 12, 5, 10]
    mylogger.info("deck_lists:{}".format(deck_lists))
    D = [Deck() for i in range(len(deck_lists))]
    deck_index = 0
    for i in deck_lists:
        mylogger.info("{}(deck_id:{}):{}".format(deck_index, i, key_2_tsv_name[i]))
        D[deck_index] = tsv_to_deck(key_2_tsv_name[i][0])
        D[deck_index].set_leader_class(key_2_tsv_name[i][1])
        deck_index += 1
    Results = {}
    list_range = range(len(deck_lists))

    Player1 = Player(9, True, policy=New_Dual_NN_Non_Rollout_OM_ISMCTSPolicy(model_name=args.model_name),
                     mulligan=Min_cost_mulligan_policy())

    if args.opponent is not None:
        if args.model_name is not None:
            if args.opponent == "Greedy":
                Player2 = Player(9, True, policy=NN_GreedyPolicy(model_name=args.model_name),
                                 mulligan=Min_cost_mulligan_policy())
            elif args.opponent == "MCTS":
                Player2 = Player(9, True, policy=New_Dual_NN_Non_Rollout_OM_ISMCTSPolicy(model_name=args.model_name),
                                 mulligan=Min_cost_mulligan_policy())
            else:
                # Guard against an unhandled option leaving Player2 unbound.
                raise ValueError("unsupported opponent: {}".format(args.opponent))
        else:
            Player2 = Player(9, True, policy=Opponent_Modeling_MCTSPolicy(),
                             mulligan=Min_cost_mulligan_policy())
    else:
        Player2 = Player(9, True, policy=AggroPolicy(),
                         mulligan=Min_cost_mulligan_policy())

    Player1.name = "Alice"
    Player2.name = "Bob"
    iteration = int(args.iteration) if args.iteration is not None else 10
    deck_list_len = len(deck_lists)
    iter_data = [(deck_list_len * i + j, Player1, Player2, (i, j),
                  (deck_lists[i], deck_lists[j]), iteration)
                 for i, j in itertools.product(list_range, list_range)]
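    # One task per ordered deck pairing; each task plays `iteration` games.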
    pool = Pool(3)  # pool of three worker processes
    result = pool.map(multi_battle, iter_data)
    pool.close()      # no new tasks will be submitted
    pool.terminate()  # tear down the worker processes
    for data in result:
        Results[data[0]] = data[1]
    print(Results)
Code Example #3
 def create_episodes(
     self,
     n_episodes: int,
     n_processes: int,
     mcts_samples: int,
     mcts_temp: float,
     mcts_cpuct: int,
     mcts_observation_weight: float,
     model: Model,
 ) -> List[Tuple[List[ObservationType], List[np.ndarray], int, Summary]]:
     pool = Pool(n_processes)
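     # Queue n_episodes identical argument lists; the pool distributes the
     # episodes across n_processes workers.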
     res = pool.starmap(
         self._generator.perform_episode,
         [[mcts_samples, mcts_temp, mcts_cpuct, mcts_observation_weight, model]]
         * n_episodes,
     )
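     # starmap has already returned every result; close/terminate/join just
     # tear down the worker processes.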
     pool.close()
     pool.terminate()
     pool.join()
     return res
Code Example #4
File: multi_test.py Project: AutumnCrocus/shadow_sim
def check_score():
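    # `args`, `cpu_num`, `evaluate_num`, `fixed_opponent`, and `deck_id_2_name`
    # are module-level globals defined elsewhere in the script.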
    print(args)
    p_size = cpu_num
    print("use cpu num:{}".format(p_size))

    loss_history = []
    check_deck_id = int(
        args.check_deck_id) if args.check_deck_id is not None else None
    cuda_flg = args.cuda == "True"
    model_name = args.model_name
    existed_output_list = os.listdir(path="./Battle_Result")
    existed_output_list = [
        f for f in existed_output_list
        if os.path.isfile(os.path.join("./Battle_Result", f))
    ]
    result_name = "{}:{}".format(model_name.split(".")[0], args.deck_list)
    same_name_count = len(
        [1 for cell in existed_output_list if result_name in cell])
    print("same_name_count:", same_name_count)
    result_name += "_{:0>3}".format(same_name_count + 1)
    PATH = 'model/' + model_name
    model_dict = torch.load(PATH)
    n_size = model_dict["final_layer.weight"].size()[1]
    net = New_Dual_Net(n_size, hidden_num=args.hidden_num[0])
    net.load_state_dict(model_dict)
    opponent_net = None
    if args.opponent_model_name is not None:
        # opponent_net = New_Dual_Net(node_num)
        o_model_name = args.opponent_model_name
        PATH = 'model/' + o_model_name
        model_dict = torch.load(PATH)
        n_size = model_dict["final_layer.weight"].size()[1]
        opponent_net = New_Dual_Net(n_size, hidden_num=args.hidden_num[1])
        opponent_net.load_state_dict(model_dict)

    if torch.cuda.is_available() and cuda_flg:
        net = net.cuda()
        opponent_net = opponent_net.cuda() if opponent_net is not None else None
        print("cuda is available.")
    #net.zero_grad()
    deck_sampling_type = False
    if args.deck is not None:
        deck_sampling_type = True
    G = Game()
    net.cpu()
    t3 = datetime.datetime.now()
    if args.greedy_mode is not None:
        p1 = Player(9, True, policy=Dual_NN_GreedyPolicy(origin_model=net))
    else:
        p1 = Player(9,
                    True,
                    policy=New_Dual_NN_Non_Rollout_OM_ISMCTSPolicy(
                        origin_model=net,
                        cuda=cuda_flg,
                        iteration=args.step_iter),
                    mulligan=Min_cost_mulligan_policy())
    #p1 = Player(9, True, policy=AggroPolicy())
    p1.name = "Alice"
    if fixed_opponent is not None:
        if fixed_opponent == "Aggro":
            p2 = Player(9,
                        False,
                        policy=AggroPolicy(),
                        mulligan=Min_cost_mulligan_policy())
        elif fixed_opponent == "OM":
            p2 = Player(9, False, policy=Opponent_Modeling_ISMCTSPolicy())
        elif fixed_opponent == "NR_OM":
            p2 = Player(9,
                        False,
                        policy=Non_Rollout_OM_ISMCTSPolicy(iteration=200),
                        mulligan=Min_cost_mulligan_policy())
        elif fixed_opponent == "ExItGreedy":
            tmp = opponent_net if opponent_net is not None else net
            p2 = Player(9,
                        False,
                        policy=Dual_NN_GreedyPolicy(origin_model=tmp))
        elif fixed_opponent == "Greedy":
            p2 = Player(9,
                        False,
                        policy=New_GreedyPolicy(),
                        mulligan=Simple_mulligan_policy())
        elif fixed_opponent == "Random":
            p2 = Player(9,
                        False,
                        policy=RandomPolicy(),
                        mulligan=Simple_mulligan_policy())
        else:
            # Guard against an unhandled option leaving p2 unbound.
            raise ValueError("unsupported fixed_opponent: {}".format(fixed_opponent))
    else:
        assert opponent_net is not None
        p2 = Player(9,
                    False,
                    policy=New_Dual_NN_Non_Rollout_OM_ISMCTSPolicy(
                        origin_model=opponent_net, cuda=cuda_flg),
                    mulligan=Min_cost_mulligan_policy())
    # p2 = Player(9, False, policy=RandomPolicy(), mulligan=Min_cost_mulligan_policy())
    p2.name = "Bob"
    Battle_Result = {}
    deck_list = tuple(map(int, args.deck_list.split(",")))
    print(deck_list)
    test_deck_list = deck_list
    test_deck_list = tuple(itertools.product(test_deck_list, test_deck_list))
    test_episode_len = evaluate_num
    episode_num = evaluate_num
    match_num = len(test_deck_list)
    manager = Manager()
    shared_array = manager.Array("i",
                                 [0 for _ in range(3 * len(test_deck_list))])
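    # shared_array holds three counters per deck pairing; slots 3*i+1 and
    # 3*i+2 are read back below as win counts.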
    iter_data = [(p1, p2, shared_array, episode_num, p_id, test_deck_list)
                 for p_id in range(p_size)]
    freeze_support()
    p1_name = p1.policy.name.replace("origin", args.model_name)
    if args.opponent_model_name is not None:
        p2_name = p2.policy.name.replace("origin", args.opponent_model_name)
    else:
        p2_name = p2.policy.name.replace("origin", args.model_name)
    print(p1_name)
    print(p2_name)
    pool = Pool(p_size, initializer=tqdm.set_lock,
                initargs=(RLock(), ))  # one worker per CPU, sharing a tqdm lock
    memory = pool.map(multi_battle, iter_data)
    pool.close()      # no new tasks will be submitted
    pool.terminate()  # tear down the worker processes
    print("\n" * (match_num + 1))
    memory = list(memory)
    min_WR = 1.0
    Battle_Result = {
        (deck_id[0], deck_id[1]): tuple(shared_array[3 * index + 1:3 * index + 3])
        for index, deck_id in enumerate(test_deck_list)
    }
    print(shared_array)
    txt_dict = {}
    for key in sorted(Battle_Result.keys()):
        cell = "{}:WR:{:.2%},first_WR:{:.2%}".format(
            key, Battle_Result[key][0] / test_episode_len,
            2 * Battle_Result[key][1] / test_episode_len)
        print(cell)
        txt_dict[key] = cell
    print(Battle_Result)
    deck_num = len(deck_list)
    # os.makedirs("Battle_Result", exist_ok=True)
    with open("Battle_Result/" + result_name, "w") as f:
        writer = csv.writer(f, delimiter='\t', lineterminator='\n')
        row = ["{} vs {}".format(p1_name, p2_name)]
        deck_names = [deck_id_2_name[deck_list[i]] for i in range(deck_num)]
        row = row + deck_names
        writer.writerow(row)
        for i in deck_list:
            row = [deck_id_2_name[i]]
            for j in deck_list:
                row.append(Battle_Result[(i, j)])
            writer.writerow(row)
        for key in list(txt_dict.keys()):
            writer.writerow([txt_dict[key]])
Code Example #5
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-modelType',
                        default='xvecTDNN',
                        help='Refer train_utils.py ')
    parser.add_argument('-numSpkrs',
                        default=7323,
                        type=int,
                        help='Number of output labels for model')
    parser.add_argument('-layerName',
                        default='fc1',
                        help="DNN layer for embeddings")
    parser.add_argument('modelDirectory',
                        help='Directory containing the model checkpoints')
    parser.add_argument(
        'featDir', help='Directory containing features ready for extraction')
    parser.add_argument('embeddingDir', help='Output directory')
    args = parser.parse_args()

    try:
        modelFile = max(glob.glob(args.modelDirectory + '/*'),
                        key=os.path.getctime)
    except ValueError:
        print("[ERROR] No trained model has been found in {}.".format(
            args.modelDirectory))
        sys.exit(1)

    # Load model definition
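    # NOTE: eval() instantiates whatever class -modelType names (e.g. xvecTDNN);
    # it trusts the command line, so only pass known model names.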
    net = eval('{}({}, p_dropout=0)'.format(args.modelType, args.numSpkrs))

    checkpoint = torch.load(modelFile, map_location=torch.device('cuda'))
    new_state_dict = OrderedDict()
    if 'relation' in args.modelType:
        checkpoint_dict = checkpoint['encoder_state_dict']
    else:
        checkpoint_dict = checkpoint['model_state_dict']
    for k, v in checkpoint_dict.items():
        if k.startswith('module.'):
            new_state_dict[k[7:]] = v  # strip the 'module.' prefix added by nn.DataParallel
        else:
            new_state_dict[k] = v

    # load trained weights
    net.load_state_dict(new_state_dict)
    net = net.cuda()
    net.eval()

    # Parallel Processing
    try:
        nSplits = int(
            sorted(glob.glob(args.featDir + '/split*'),
                   key=getSplitNum)[-1].split('/')[-1].lstrip('split'))
    except (IndexError, ValueError):
        print('[ERROR] Cannot find %s/splitN directory' % args.featDir)
        sys.exit(1)

    if not os.path.isdir(args.embeddingDir):
        os.makedirs(args.embeddingDir)

    print('Extracting xvectors by distributing jobs to pool workers... ')
    nProcs = nSplits
    L = [('%s/split%d/%d/feats.scp' % (args.featDir, nSplits, i),
          '%s/xvector.%d.ark' % (args.embeddingDir, i),
          '%s/xvector.%d.scp' % (args.embeddingDir, i), net, args.layerName)
         for i in range(1, nSplits + 1)]
    pool2 = Pool(processes=nProcs)
    result = pool2.starmap(par_core_extractXvectors, L)
    pool2.terminate()
    print('Multiprocess extraction has finished.')

    print('Writing xvectors to {}'.format(args.embeddingDir))
    os.system('cat %s/xvector.*.scp > %s/xvector.scp' %
              (args.embeddingDir, args.embeddingDir))
Code Example #6
import pandas as pd
from multiprocessing import Pool, set_start_method
from fastai.vision.all import *  # untar_data, URLs, DataBlock, ImageBlock, ...


# The head of filelist2df is cut off in the original listing; this is a
# plausible reconstruction (assumed), consistent with the ColReader calls below.
def filelist2df(path):
    # Each line of train.txt/test.txt looks like "<label>/<image_id>"; build a
    # dataframe with the label and the image path relative to `images/`.
    df = pd.read_csv(path, delimiter='/', header=None, names=['label', 'name'])
    df['name'] = df['label'].astype(str) + '/' + df['name'].astype(str) + '.jpg'
    return df


path = untar_data(URLs.FOOD)
train_path = path / 'train.txt'
test_path = path / 'test.txt'


def load_data(index):
    train_df = filelist2df(train_path)
    test_df = filelist2df(test_path)
    food = DataBlock(blocks=(ImageBlock, CategoryBlock),
                     get_x=ColReader(1, pref=path / 'images'),
                     splitter=RandomSplitter(),
                     get_y=ColReader(cols=0),
                     item_tfms=Resize(224))
    dls = food.dataloaders(train_df.values, bs=64)


if __name__ == '__main__':
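    # 'spawn' starts each worker in a fresh interpreter, avoiding fork-related
    # deadlocks with multi-threaded libraries such as PyTorch and fastai.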
    set_start_method('spawn', force=True)
    pool = Pool(8)  # created outside try so `pool` is always bound in finally
    try:
        pool.map(load_data, [1, 2, 3, 4, 5, 6, 7, 8])
    except KeyboardInterrupt:
        exit()
    finally:
        # terminate() kills the workers even if map() did not finish cleanly.
        pool.terminate()
        pool.join()
Code Example #7
import gym
import torch
from functools import partial
from multiprocessing import Pool

# do_rollout_train, update_mean, update_std, and postprocess_default are
# assumed to come from the surrounding project.


def ars(env_name,
        policy,
        n_epochs,
        n_workers=8,
        step_size=.02,
        n_delta=32,
        n_top=16,
        exp_noise=0.03,
        zero_policy=True,
        postprocess=postprocess_default):
    """
    Augmented Random Search (ARS)
    https://arxiv.org/pdf/1803.07055

    Args:
        env_name: gym environment id used for the rollouts.
        policy: torch policy whose flat parameter vector is optimized.
        n_epochs: number of ARS update steps.
        n_workers: size of the rollout worker pool.
        step_size: update step size.
        n_delta: number of random perturbation directions per epoch.
        n_top: number of top-performing directions used in the update.
        exp_noise: scale of the exploration noise applied to the parameters.
        zero_policy: if True, start from an all-zero parameter vector.
        postprocess: reward post-processing hook applied during rollouts.

    Returns:
        (policy, r_hist): the trained policy and the per-epoch mean returns.
    """
    # Rollouts only evaluate the policy, so no gradients are needed.
    torch.autograd.set_grad_enabled(False)

    pool = Pool(processes=n_workers)
    env = gym.make(env_name)
    W = torch.nn.utils.parameters_to_vector(policy.parameters())
    n_param = W.shape[0]

    if zero_policy:
        W = torch.zeros_like(W)

    r_hist = []
    s_mean = torch.zeros(env.observation_space.shape[0])
    s_stdv = torch.ones(env.observation_space.shape[0])

    total_steps = 0
    exp_dist = torch.distributions.Normal(torch.zeros(n_delta, n_param),
                                          torch.ones(n_delta, n_param))
    do_rollout_partial = partial(do_rollout_train, env_name, policy,
                                 postprocess)

    for _ in range(n_epochs):

        deltas = exp_dist.sample()
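        # Evaluate the policy at W + noise*delta and W - noise*delta
        # (antithetic pairs), one rollout per pool worker.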
        pm_W = torch.cat((W + (deltas * exp_noise), W - (deltas * exp_noise)))

        results = pool.map(do_rollout_partial, pm_W)

        states = torch.empty(0)
        p_returns = []
        m_returns = []
        l_returns = []
        top_returns = []

        for p_result, m_result in zip(results[:n_delta], results[n_delta:]):
            ps, pr, plr = p_result
            ms, mr, mlr = m_result

            states = torch.cat((states, ms, ps), dim=0)
            p_returns.append(pr)
            m_returns.append(mr)
            l_returns.append(plr)
            l_returns.append(mlr)
            top_returns.append(max(pr, mr))

        top_idx = sorted(range(len(top_returns)),
                         key=lambda k: top_returns[k],
                         reverse=True)[:n_top]
        p_returns = torch.stack(p_returns)[top_idx]
        m_returns = torch.stack(m_returns)[top_idx]
        l_returns = torch.stack(l_returns)[top_idx]

        r_hist.append(l_returns.mean())

        ep_steps = states.shape[0]
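        # Update the running mean/std used to normalize observations.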
        s_mean = update_mean(states, s_mean, total_steps)
        s_stdv = update_std(states, s_stdv, total_steps)
        total_steps += ep_steps

        policy.state_means = s_mean
        policy.state_std = s_stdv
        do_rollout_partial = partial(do_rollout_train, env_name, policy,
                                     postprocess)

        # ARS update: step along the top-n_top return differences, scaled by
        # the standard deviation of those returns.
        W = W + (step_size / (n_delta * torch.cat(
            (p_returns, m_returns)).std() + 1e-6)) * torch.sum(
                (p_returns - m_returns) * deltas[top_idx].T, dim=1)

    pool.terminate()
    torch.nn.utils.vector_to_parameters(W, policy.parameters())
    return policy, r_hist