import torch
import torch.nn as nn
from torch.distributions import Categorical

# `device`, `hyper_params`, `calc_reward_given_descriptor`, and the two
# moving-average globals used below are defined elsewhere in this project.


def train(meta_decoder, decoder_optimizer, fclayers_for_hyper_params):
    global moving_average
    global moving_average_alpha

    decoder_hidden = meta_decoder.initHidden()
    decoder_optimizer.zero_grad()

    output = torch.zeros([1, 1, meta_decoder.output_size], device=device)
    softmax = nn.Softmax(dim=1)
    softmax_outputs_stored = list()

    # Roll out the controller RNN: the first three decoder steps emit
    # distributions over the first three hyper-parameters (the third one
    # is the interaction type).
    for i in range(3):
        output, decoder_hidden = meta_decoder(output, decoder_hidden)
        softmax_outputs_stored.append(
            softmax(fclayers_for_hyper_params[hyper_params[i][0]](output)))

    # Sample the interaction type; it decides which extra hyper-parameters
    # still need to be decoded.
    output_interaction = softmax_outputs_stored[-1]
    type_of_interaction = Categorical(output_interaction).sample().tolist()[0]

    if type_of_interaction == 0:  # PairwiseEuDist
        for i in range(3, 4):
            output, decoder_hidden = meta_decoder(output, decoder_hidden)
            softmax_outputs_stored.append(
                softmax(fclayers_for_hyper_params[hyper_params[i][0]](output)))
    elif type_of_interaction == 1:  # PairwiseLog
        # This interaction type has no extra hyper-parameters.
        pass
    else:  # PointwiseMLPCE
        for i in range(4, 7):
            output, decoder_hidden = meta_decoder(output, decoder_hidden)
            softmax_outputs_stored.append(
                softmax(fclayers_for_hyper_params[hyper_params[i][0]](output)))

    # Sample one concrete choice from each stored distribution to obtain
    # the architecture descriptor.
    resulted_str = []
    for outputs in softmax_outputs_stored:
        print("softmax_outputs: ", outputs)
        idx = Categorical(outputs).sample()
        resulted_str.append(idx.tolist()[0])
    # The interaction type has already been sampled above; keep that sample.
    resulted_str[2] = type_of_interaction

    resulted_idx = resulted_str
    resulted_str = "_".join(map(str, resulted_str))
    print("resulted_str: " + resulted_str)

    # Evaluate the sampled descriptor, then turn the raw reward into an
    # advantage by subtracting an exponential moving-average baseline.
    reward = calc_reward_given_descriptor(resulted_str)
    if moving_average == -19013:  # sentinel: baseline not initialised yet
        moving_average = reward
        reward = 0.0
    else:
        tmp = reward
        reward = reward - moving_average
        moving_average = moving_average_alpha * tmp + (
            1.0 - moving_average_alpha) * moving_average

    print("current reward: " + str(reward))
    print("current moving average: " + str(moving_average))

    # REINFORCE objective: the sum of log-probabilities of the sampled
    # choices, weighted by the advantage; the loss is its negation.
    expectedReward = 0
    for i in range(len(softmax_outputs_stored)):
        logprob = torch.log(softmax_outputs_stored[i][0][resulted_idx[i]])
        expectedReward += logprob * reward
    loss = -expectedReward
    print('loss:', loss)

    # Finally, backpropagate the loss according to the policy gradient.
    loss.backward()
    decoder_optimizer.step()
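# --- Usage sketch (an assumption, not part of the original code) ---
# A minimal driver for the controller above. `MetaDecoder`, the Adam
# hyper-parameters, the episode count, and the (name, n_choices) layout
# of `hyper_params` are all hypothetical; only train() itself comes from
# this file.

# Baseline globals required by train(); -19013 is the sentinel that
# train() checks to detect an uninitialised baseline. The smoothing
# factor below is an assumed value.
moving_average = -19013
moving_average_alpha = 0.9

if __name__ == "__main__":
    meta_decoder = MetaDecoder()  # hypothetical constructor defined elsewhere
    decoder_optimizer = torch.optim.Adam(meta_decoder.parameters(), lr=1e-3)
    # One classification head per hyper-parameter, mapping the decoder
    # output to that hyper-parameter's candidate choices (assumes each
    # entry of hyper_params is a (name, n_choices) pair).
    fclayers_for_hyper_params = {
        name: nn.Linear(meta_decoder.output_size, n_choices)
        for name, n_choices in hyper_params
    }
    for episode in range(200):  # hypothetical episode count
        train(meta_decoder, decoder_optimizer, fclayers_for_hyper_params)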