Example 1
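                # log SAC learner losses, temperature (alpha), and policy entropy to TensorBoard at step t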
                writer.add_scalar("learner/pi_loss",
                                  loss_pi.detach().item(), t)
                writer.add_scalar("learner/q_loss", loss_q.detach().item(), t)
                writer.add_scalar("learner/alpha_loss",
                                  alpha_loss.detach().item(), t)
                writer.add_scalar("learner/alpha", alpha, t)
                writer.add_scalar("learner/entropy",
                                  entropy.detach().mean().item(), t)

        # CPC update handling
        if args.cpc and e > args.cpc_batch * 2 and e % args.cpc_update_freq == 0:
            for _ in range(args.cpc_update_freq):
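                # sample a batch of stored trajectories from the replay buffer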
                data, indexes, min_len = replay_buffer.sample_traj(
                    args.cpc_batch)
                cpc_optimizer.zero_grad()
                c_hidden = global_cpc.init_hidden(len(data), args.c_dim)
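                # CPC forward pass: contrastive accuracy, loss, and per-step latent codes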
                acc, loss, latents = global_cpc(data, c_hidden)

                # replay_buffer.update_latent(indexes, min_len, latents.detach())
                loss.backward()
                # clip gradients to stabilize the CPC update
                nn.utils.clip_grad_norm_(global_cpc.parameters(),
                                         max_norm=20,
                                         norm_type=2)
                cpc_optimizer.step()
                writer.add_scalar("learner/cpc_acc", acc, t)
                writer.add_scalar("learner/cpc_loss", loss.detach().item(), t)

        # CPC latent update
        # if args.cpc and e > args.cpc_batch and e % 500 == 0 and e != last_updated:
        #     replay_buffer.create_latents(e=e)
Example 2
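        # restore the SAC actor-critic weights from the saved checkpoint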
        # global_ac.load_state_dict(torch.load(os.path.join(args.save_dir, args.exp_name, args.model_para)))
        load_my_state_dict(
            global_ac,
            os.path.join(args.save_dir, args.exp_name, args.model_para))
        print("load sac model")

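    # optionally resume the CPC model from its saved checkpoint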
    if os.path.exists(os.path.join(args.save_dir, args.exp_name,
                                   args.cpc_para)) and args.cpc:
        global_cpc.load_state_dict(
            torch.load(
                os.path.join(args.save_dir, args.exp_name, args.cpc_para)))
        print("load cpc model")

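    # reset the environment and the per-episode return/length counters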
    o, ep_ret, ep_len = env.reset(), 0, 0
    if args.cpc:
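        # encode the first observation into the initial context embedding c1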
        c_hidden = global_cpc.init_hidden(1, args.c_dim, use_gpu=args.cuda)
        c1, c_hidden = global_cpc.predict(o, c_hidden)
        assert len(c1.shape) == 3
        c1 = c1.flatten().cpu().numpy()
        round_embedding = []
        all_embeddings = []
        meta = []
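    # per-episode bookkeeping: trajectory, opponent (p2) history, uncertainty/GLOD buffers, and win/score statistics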
    trajectory = list()
    p2 = env.p2
    p2_list = [str(p2)]
    discard = False
    uncertainties = []
    glod_input = defaultdict(list)
    glod_target = defaultdict(list)
    wins, scores, win_rate, m_score = [], [], 0, 0
    local_t, local_e = 0, 0
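
Note: load_my_state_dict, used in Example 2, is not defined in these excerpts. A minimal sketch of one plausible implementation, assuming it is a partial loader that copies only the checkpoint entries whose names and shapes match the target model (the actual behavior in the source may differ):

import torch


def load_my_state_dict(model, path):
    # load the checkpoint onto the CPU and keep only entries that match the model
    checkpoint = torch.load(path, map_location="cpu")
    own_state = model.state_dict()
    matched = {
        k: v
        for k, v in checkpoint.items()
        if k in own_state and v.shape == own_state[k].shape
    }
    own_state.update(matched)
    model.load_state_dict(own_state)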