Example #1
# MODEL (the generator `gen` is assumed to be constructed earlier in the script)
crit = Critic(in_dim, hidden_dim_d, num_layers).to(device)


def weights_init(m):
    """Xavier-initialize Linear and LSTM weights; LSTM biases are zeroed."""
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
    elif isinstance(m, nn.LSTM):
        for name, param in m.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0.0)
            elif 'weight' in name:
                nn.init.xavier_normal_(param)


gen = gen.apply(weights_init)
crit = crit.apply(weights_init)

# OPTIMIZER
lr = 0.0002
beta_1 = 0.9  # 0.5
beta_2 = 0.999
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(beta_1, beta_2))

# LOADING CKPTS
'''
gen_ckpt_path = "/Users/parkerzhao/Desktop/Projects/gans_panel_data/code/" \
                "conditional_tsgan/experiments/exp6/ckpts/generator_epoch_80.pth.tar"
disc_ckpt_path = "/Users/parkerzhao/Desktop/Projects/gans_panel_data/code/" \
                 "conditional_tsgan/experiments/exp6/ckpts/critic_epoch_80.pth.tar"
'''
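
# A minimal resume sketch (assumption, not in the original snippet: each .pth.tar above
# stores a raw state_dict saved via torch.save(model.state_dict(), path); uncomment
# together with the paths above to restore the models):
# gen.load_state_dict(torch.load(gen_ckpt_path, map_location=device))
# crit.load_state_dict(torch.load(disc_ckpt_path, map_location=device))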
Example #2
gen = Generator(
    im_ch=IMAGE_CHANNELS,
    latent_dim=NOISE_DIM,
    hidden_dim=HIDDEN_DIM_GEN,
    use_batchnorm=USE_BATCHNORM,
    upsample_mode=UPSAMPLE_MODE,
)
gen = gen.to(device)
critic = Critic(
    im_ch=IMAGE_CHANNELS,
    hidden_dim=HIDDEN_DIM_DISC,
    use_batchnorm=USE_BATCHNORM,
    spectral_norm=SPECTRAL_NORM,
)
critic = critic.to(device)

critic.apply(init_weights)
gen.apply(init_weights)

# configure loss and optimizers
criterion = nn.BCEWithLogitsLoss()
opt_gen = torch.optim.Adam(gen.parameters(), lr=LR, betas=(beta1, beta2))
opt_disc = torch.optim.Adam(critic.parameters(), lr=LR, betas=(beta1, beta2))

# configure tensorboard writer
repo = git.Repo(search_parent_directories=True)
sha = repo.head.object.hexsha[:6]
logdir = f"/home/bishwarup/GAN_experiments/dcgan/{sha}"
writer = SummaryWriter(log_dir=logdir)

# make a fixed noise to see the generator evolve over time on it
fixed_noise = gen_noise(32, NOISE_DIM, device=device)
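
# A minimal per-batch training sketch tying the pieces above together (assumptions, not
# part of the original file: `loader` yields batches of real images, `step` is a running
# batch counter, torchvision is available, and the non-saturating BCE formulation is used):
#
# for real, _ in loader:
#     real = real.to(device)
#     noise = gen_noise(real.size(0), NOISE_DIM, device=device)
#     fake = gen(noise)
#
#     # critic/discriminator step: push real towards 1 and fake towards 0
#     disc_real = critic(real).view(-1)
#     disc_fake = critic(fake.detach()).view(-1)
#     loss_disc = (criterion(disc_real, torch.ones_like(disc_real)) +
#                  criterion(disc_fake, torch.zeros_like(disc_fake))) / 2
#     opt_disc.zero_grad()
#     loss_disc.backward()
#     opt_disc.step()
#
#     # generator step: make the critic output 1 on fakes
#     output = critic(fake).view(-1)
#     loss_gen = criterion(output, torch.ones_like(output))
#     opt_gen.zero_grad()
#     loss_gen.backward()
#     opt_gen.step()
#
#     # log fixed-noise samples so the generator's progress is visible in TensorBoard
#     with torch.no_grad():
#         grid = torchvision.utils.make_grid(gen(fixed_noise), normalize=True)
#     writer.add_image("generated", grid, global_step=step)
#     step += 1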
Example #3
class Agent():
    """Agent that plays and learn from experience. Hyper-paramters chosen from paper."""
    def __init__(self,
                 state_size,
                 action_size,
                 max_action,
                 discount=0.99,
                 tau=0.005,
                 policy_noise=0.2,
                 noise_clip=0.5,
                 policy_freq=2):
        """
        Initializes the Agent.
        @Param:
        1. state_size: env.observation_space.shape[0]
        2. action_size: env.action_size.shape[0]
        3. max_action: list of max values that the agent can take, i.e. abs(env.action_space.high)
        4. discount: return rate
        5. tau: soft target update
        6. policy_noise: noise reset level, DDPG uses Ornstein-Uhlenbeck process
        7. noise_clip: sets boundary for noise calculation to prevent from overestimation of Q-values
        8. policy_freq: number of timesteps to update the policy (actor) after
        """
        super(Agent, self).__init__()

        #Actor Network initialization
        self.actor = Actor(state_size, action_size, max_action).to(device)
        self.actor.apply(self.init_weights)
        self.actor_target = copy.deepcopy(
            self.actor)  #loads main model into target model
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=0.001)

        #Critic Network initialization
        self.critic = Critic(state_size, action_size).to(device)
        self.critic.apply(self.init_weights)
        self.critic_target = copy.deepcopy(
            self.critic)  #loads main model into target model
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=0.001)

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.total_it = 0

    def init_weights(self, layer):
        """Xavier initialization of weights for Linear layers"""
        if isinstance(layer, nn.Linear):
            nn.init.xavier_normal_(layer.weight)
            layer.bias.data.fill_(0.01)

    def select_action(self, state):
        """Selects an automatic epsilon-greedy action based on the policy"""
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.actor(state).cpu().data.numpy().flatten()

    def train(self, replay_buffer: ReplayBuffer):
        """Train the Agent"""

        self.total_it += 1

        # Sample replay buffer
        state, action, reward, next_state, done = replay_buffer.sample(
        )  #sample 256 experiences

        with torch.no_grad():
            # Select action according to policy and add clipped noise
            noise = (torch.randn_like(action) * self.policy_noise).clamp(
                -self.noise_clip, self.noise_clip)

            next_action = (
                self.actor_target(next_state) +
                noise  # target policy smoothing: noise is only added when forming the TD target
            ).clamp(-self.max_action, self.max_action)

            # Compute the target Q value
            target_Q1, target_Q2 = self.critic_target(next_state,
                                                      next_action)  #Q1, Q2
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + (1 -
                                 done) * self.discount * target_Q  #TD-target

        # Get current Q estimates
        current_Q1, current_Q2 = self.critic(state, action)  #Q1, Q2

        # Compute critic loss using MSE
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
            current_Q2, target_Q)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy updates (DDPG baseline = 1)
        if (self.total_it % self.policy_freq == 0):

            # Compute actor loss (deterministic policy gradient through Q1 only)
            actor_loss = -self.critic(state, self.actor(state))[0].mean()

            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Soft update by updating the frozen target models
            for param, target_param in zip(self.critic.parameters(),
                                           self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data +
                                        (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(),
                                           self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data +
                                        (1 - self.tau) * target_param.data)

    def save(self, filename):
        """Saves the Actor Critic local and target models"""
        torch.save(self.critic.state_dict(),
                   "models/checkpoint/" + filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(),
                   "models/checkpoint/" + filename + "_critic_optimizer")

        torch.save(self.actor.state_dict(),
                   "models/checkpoint/" + filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(),
                   "models/checkpoint/" + filename + "_actor_optimizer")

    def load(self, filename):
        """Loads the Actor Critic local and target models"""
        self.critic.load_state_dict(
            torch.load("models/checkpoint/" + filename + "_critic",
                       map_location='cpu'))
        self.critic_optimizer.load_state_dict(
            torch.load("models/checkpoint/" + filename + "_critic_optimizer",
                       map_location='cpu'))  #optional
        self.critic_target = copy.deepcopy(self.critic)

        self.actor.load_state_dict(
            torch.load("models/checkpoint/" + filename + "_actor",
                       map_location='cpu'))
        self.actor_optimizer.load_state_dict(
            torch.load("models/checkpoint/" + filename + "_actor_optimizer",
                       map_location='cpu'))  #optional
        self.actor_target = copy.deepcopy(self.actor)
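
# A minimal training-loop sketch for the Agent above (assumptions, not part of the
# original file: a gym>=0.26 environment, a ReplayBuffer exposing add(...) and sample(),
# and Gaussian exploration noise added outside select_action, as in the TD3 reference code):
#
# env = gym.make("Pendulum-v1")
# max_action = float(env.action_space.high[0])
# agent = Agent(state_size=env.observation_space.shape[0],
#               action_size=env.action_space.shape[0],
#               max_action=max_action)
# replay_buffer = ReplayBuffer(...)
# state, _ = env.reset()
# for t in range(1, total_timesteps + 1):
#     action = agent.select_action(state)
#     action = (action + np.random.normal(0, 0.1 * max_action, size=action.shape)).clip(
#         -max_action, max_action)
#     next_state, reward, terminated, truncated, _ = env.step(action)
#     replay_buffer.add(state, action, reward, next_state, float(terminated))
#     state = next_state
#     if terminated or truncated:
#         state, _ = env.reset()
#     if t > warmup_steps:
#         agent.train(replay_buffer)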
Example #4
# Network, Optimizers, and Weight Initialization
# ----------------------------------------------------

# Discriminator and Generator init
gen = Generator(Z_DIM, IMG_CHANNELS, FEATURES_GEN).to(device)
critic = Critic(IMG_CHANNELS, FEATURES_CRITIC).to(device)
print('\n>>> Network [D]iscriminator & [G]enerator Initialized')

# Optimizers Initialization
opt_gen = optim.RMSprop(gen.parameters(), lr=LEARNING_RATE)
opt_critic = optim.RMSprop(critic.parameters(), lr=LEARNING_RATE)
print('\n>>> Optimizers for [D] & [G] Initialized')

# Initialize weights
gen = gen.apply(weights_init)
critic = critic.apply(weights_init)
print('\n>>> Network weights initialized')

arch_info = f"""
-------------------------------------------------------------------
Network Architectures :
--------------------<< Generator >>--------------------------------
{gen}
--------------------<< Discriminator >>----------------------------
{critic}
"""

print(arch_info)
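
# A minimal per-batch update sketch (assumption, not part of the original file: a standard
# weight-clipping WGAN, which matches the RMSprop optimizers above; `loader`, `real`, and
# WEIGHT_CLIP are hypothetical names, and the noise shape assumes a DCGAN-style generator):
#
# for real, _ in loader:
#     real = real.to(device)
#     # critic step: maximize E[critic(real)] - E[critic(fake)]
#     noise = torch.randn(real.size(0), Z_DIM, 1, 1, device=device)
#     fake = gen(noise)
#     loss_critic = -(critic(real).mean() - critic(fake.detach()).mean())
#     opt_critic.zero_grad()
#     loss_critic.backward()
#     opt_critic.step()
#     for p in critic.parameters():
#         p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)
#     # generator step: maximize E[critic(fake)]
#     loss_gen = -critic(fake).mean()
#     opt_gen.zero_grad()
#     loss_gen.backward()
#     opt_gen.step()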

# Statistics to be saved
# ----------------------------------------------------