Пример #1
0
def choose_model(f_dict, ter_dict, *, tol=0.03):
    """Choose between two aggregated global state dicts by test accuracy.

    Each state dict is loaded into a fresh ``Fed_Model`` and evaluated with
    the module-level ``G_loss_fun``, ``test_iter`` and ``Args``. Both
    accuracies are printed. If they differ by less than ``tol``,
    ``ter_dict`` is kept; otherwise ``f_dict`` is kept.

    Args:
        f_dict: state dict printed as ``F`` (presumably the full-precision
            aggregate — confirm against ``ServerUpdate``).
        ter_dict: state dict printed as ``TF`` (presumably the ternary
            aggregate — confirm against ``ServerUpdate``).
        tol: keyword-only; maximum absolute accuracy gap for which
            ``ter_dict`` is accepted. Defaults to 0.03, the previously
            hard-coded value. NOTE(review): this assumes ``evaluate``
            reports accuracy as a fraction in [0, 1] — confirm.

    Returns:
        ``(state_dict, used_ter)``: the chosen state dict and ``True`` when
        it is ``ter_dict``, ``False`` when it is ``f_dict``.
    """
    net_f = Fed_Model()
    net_ter = Fed_Model()
    net_f.load_state_dict(f_dict)
    net_ter.load_state_dict(ter_dict)

    _, acc_f, _ = evaluate(net_f, G_loss_fun, test_iter, Args)
    _, acc_ter, _ = evaluate(net_ter, G_loss_fun, test_iter, Args)
    print('F: %.3f' % acc_f, 'TF: %.3f' % acc_ter)

    # NOTE(review): the gap is symmetric, so a ter_dict model that is *better*
    # than f_dict by tol or more is rejected too; confirm this is intended.
    # bool() keeps the flag a plain Python bool even if evaluate returns
    # numpy/torch scalars.
    used_ter = bool(abs(acc_f - acc_ter) < tol)
    return (ter_dict if used_ter else f_dict), used_ter
Пример #2
0
# Display symbol for each action code stored in the visualisation grid;
# -1 marks the goal cell.
_ACTION_SYMBOLS = {0: "R", 1: "L", 2: "U", 3: "D", -1: "G"}


def _print_action_map(vis_map):
    """Print one symbol per cell of ``vis_map``, one grid row per line.

    Codes are looked up in ``_ACTION_SYMBOLS``. An unrecognised code prints
    as ``?`` so the grid columns stay aligned.
    """
    for row in vis_map:
        for code in row:
            print(_ACTION_SYMBOLS.get(code, "?"), end=" ")
        # print("\n") emits a blank line after each row (original format).
        print("\n")


def _print_dist_map(dist_map):
    """Print the raw values of ``dist_map``, one grid row per line."""
    for row in dist_map:
        for value in row:
            print(value, end=" ")
        print("\n")


def visualize(args):
    """Print the best action and its distance for every cell of the grid.

    Builds the two MLPs and the agent, then restores the MLP weights from
    ``args.model_path``. For every cell ``(i + 1, j + 1)`` it asks
    ``agent.get_best_action`` for the best action toward ``args.goal_pos``.

    Two grids are printed:
      * the action symbols: ``R``/``L``/``U``/``D``, ``G`` for the goal and
        ``?`` for an unrecognised code;
      * the ``min_dist`` value returned for each cell (0 at the goal).

    Args:
        args: namespace providing ``env_size``, ``s_a_hidden_size``,
            ``s_hidden_size``, ``embedding_dim``, ``model_path`` and
            ``goal_pos``.
    """
    env = BaseEnv(size=args.env_size)
    f_s_a = MLP(env.state_size + env.action_size, args.s_a_hidden_size,
                args.embedding_dim)
    f_s = MLP(env.state_size, args.s_hidden_size, args.embedding_dim)
    agent = Agent(env_size=args.env_size)

    # The networks above live on the CPU. map_location lets a checkpoint
    # saved from a GPU run be restored on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints.
    checkpoint = torch.load(args.model_path, map_location="cpu")
    f_s_a.load_state_dict(checkpoint['s_a_state_dict'])
    f_s.load_state_dict(checkpoint['s_state_dict'])
    f_s_a.eval()
    f_s.eval()

    vis_map = np.zeros(args.env_size)
    dist_map = np.zeros(args.env_size)
    # Cells are compared as tuples, so normalise goal_pos (e.g. a list from
    # argparse); otherwise the goal cell would never be matched.
    goal = tuple(args.goal_pos)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # The goal feature does not depend on the cell; compute it once.
        goal_feature = agent.get_state_feature(goal)
        for i in range(vis_map.shape[0]):
            for j in range(vis_map.shape[1]):
                cell = (i + 1, j + 1)  # 1-based grid coordinates
                if cell == goal:
                    vis_map[i, j] = -1
                    continue
                state_feature = agent.get_state_feature(cell)
                best_action, min_dist = agent.get_best_action(
                    state_feature, f_s_a, f_s, goal_feature)
                vis_map[i, j] = best_action
                dist_map[i, j] = min_dist

    _print_action_map(vis_map)
    _print_dist_map(dist_map)
Пример #3
0
        # Local training: each selected client trains a copy of the global
        # model. It returns its weights plus an updated wp_lists, which is
        # stored back in c_lists[idx] for that client's next round.
        for idx in client_id:
            local = LocalUpdate(client_name=idx,
                                c_round=rounds,
                                train_iter=C_iter[idx],
                                test_iter=test_iter,
                                wp_lists=c_lists[idx],
                                args=Args)
            # Train on a deep copy so G_net itself is not trained directly;
            # it only receives the aggregated weights below.
            w, wp_lists = local.TFed_train(
                net=copy.deepcopy(G_net).to(Args.device))
            c_lists[idx] = wp_lists
            w_locals.append(copy.deepcopy(w))

            # Client dataset size; presumably the aggregation weight used by
            # ServerUpdate — confirm.
            num_samp.append(len(C_iter[idx].dataset))
        # update global weights
        # ServerUpdate yields two aggregated state dicts: w_glob and ter_glob
        # (presumably the ternary variant — confirm).
        w_glob, ter_glob = ServerUpdate(w_locals, num_samp)

        # choose_model returns ter_glob (tmp_flag True) when the two models'
        # test accuracies are within its tolerance, else w_glob (tmp_flag False).
        w_glob, tmp_flag = choose_model(w_glob, ter_glob)
        if tmp_flag:
            # Count rounds in which ter_glob was selected.
            # NOTE(review): the counter is num_s2 but the message says 'S1';
            # confirm which label is intended.
            num_s2 += 1
            print('S1')

        # reload global network weights
        G_net.load_state_dict(w_glob)

        # verify accuracy on the test set
        # NOTE(review): choose_model also selected the weights using this same
        # test_iter, so this is not an accuracy on held-out data.
        g_loss, g_acc, g_acc5 = evaluate(G_net, G_loss_fun, test_iter, Args)
        gv_acc.append(g_acc)

        print('Round {:3d}, Global loss {:.3f}, Global Acc {:.3f}'.format(
            rounds, g_loss, g_acc))