def cal_loss(self, output, target):
    """
    Build yolo loss

    Arguments:
    output -- tuple (delta_pred, conf_pred, class_score), output data of the yolo network
    target -- tuple (iou_target, iou_mask, box_target, box_mask, class_target, class_mask), target label data

    delta_pred -- Variable of shape (B, H * W * num_anchors, 4), predictions of delta σ(t_x), σ(t_y), σ(t_w), σ(t_h)
    conf_pred -- Variable of shape (B, H * W * num_anchors, 1), prediction of IoU score σ(t_c)
    class_score -- Variable of shape (B, H * W * num_anchors, num_classes), prediction of class scores (cls1, cls2, ...)

    iou_target -- Variable of shape (B, H * W * num_anchors, 1)
    iou_mask -- Variable of shape (B, H * W * num_anchors, 1)
    box_target -- Variable of shape (B, H * W * num_anchors, 4)
    box_mask -- Variable of shape (B, H * W * num_anchors, 1)
    class_target -- Variable of shape (B, H * W * num_anchors, 1)
    class_mask -- Variable of shape (B, H * W * num_anchors, 1)

    Return:
    loss -- yolo overall multi-task loss
    """
    delta_pred_batch = output[0]
    conf_pred_batch = output[1]
    class_score_batch = output[2]

    iou_target = target[0]
    iou_mask = target[1]
    box_target = target[2]
    box_mask = target[3]
    class_target = target[4]
    class_mask = target[5]

    b, _, num_classes = class_score_batch.size()
    class_score_batch = class_score_batch.view(-1, num_classes)
    class_target = class_target.view(-1)
    class_mask = class_mask.view(-1)

    # ignore the gradient of noobject's target
    class_keep = class_mask.nonzero().squeeze(1)
    class_score_batch_keep = class_score_batch[class_keep, :]
    class_target_keep = class_target[class_keep]

    # if cfg.debug:
    #     print(class_score_batch_keep)
    #     print(class_target_keep)

    # calculate the loss, normalized by batch size.
    box_loss = 1 / b * 1 * F.mse_loss(delta_pred_batch * box_mask,
                                      box_target * box_mask,
                                      reduction='sum') / 2.0
    iou_loss = 1 / b * F.mse_loss(conf_pred_batch * iou_mask,
                                  iou_target * iou_mask,
                                  reduction='sum') / 2.0
    class_loss = 1 / b * 1 * F.cross_entropy(class_score_batch_keep,
                                             class_target_keep,
                                             reduction='sum')

    return box_loss, iou_loss, class_loss
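# Hedged, self-contained sketch of the same three loss terms on dummy tensors of the
# shapes documented above. The grid size (13x13), anchor count (5) and class count (20)
# are illustrative assumptions, not values taken from the surrounding project.
import torch
import torch.nn.functional as F

B, N, num_classes = 2, 13 * 13 * 5, 20
delta_pred = torch.rand(B, N, 4)
conf_pred = torch.rand(B, N, 1)
class_score = torch.rand(B, N, num_classes)

iou_target, iou_mask = torch.rand(B, N, 1), torch.ones(B, N, 1)
box_target, box_mask = torch.rand(B, N, 4), torch.ones(B, N, 1)
class_target, class_mask = torch.randint(0, num_classes, (B, N, 1)), torch.ones(B, N, 1)

class_score = class_score.view(-1, num_classes)
class_target = class_target.view(-1)
keep = class_mask.view(-1).nonzero().squeeze(1)

box_loss = 1 / B * F.mse_loss(delta_pred * box_mask, box_target * box_mask, reduction='sum') / 2.0
iou_loss = 1 / B * F.mse_loss(conf_pred * iou_mask, iou_target * iou_mask, reduction='sum') / 2.0
class_loss = 1 / B * F.cross_entropy(class_score[keep, :], class_target[keep], reduction='sum')
print(box_loss.item(), iou_loss.item(), class_loss.item())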
def forward(self, input, target):
    return (F.mse_loss(input[0:3], target[0:3],
                       size_average=self.size_average,
                       reduce=self.reduce)
            + 100 * F.mse_loss(input[3:6], target[3:6],
                               size_average=self.size_average,
                               reduce=self.reduce))
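# Note: size_average / reduce are deprecated in recent PyTorch in favour of a single
# `reduction` argument. Assuming the defaults (size_average=True, reduce=True), the
# weighted criterion above is equivalent to the short sketch below on dummy tensors.
import torch
import torch.nn.functional as F

inp, tgt = torch.rand(6), torch.rand(6)
loss = (F.mse_loss(inp[0:3], tgt[0:3], reduction='mean')
        + 100 * F.mse_loss(inp[3:6], tgt[3:6], reduction='mean'))
print(loss.item())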
def train_minibatch(self, minibatch):
    self.net_optim.zero_grad()

    rewards = torch.from_numpy(
        np.zeros((len(minibatch), 1)).astype("float32"))
    search_probas = torch.from_numpy(
        np.zeros((len(minibatch), *minibatch[0]["search_probas"].shape)).astype("float32"))
    states = torch.from_numpy(
        np.zeros((len(minibatch), *minibatch[0]["state"].shape)).astype("float32"))

    for i, memory in enumerate(minibatch):
        rewards[i] = memory["reward"]
        search_probas[i] = memory["search_probas"]
        states[i] = memory["state"]

    rewards = self.V(rewards)
    search_probas = self.V(search_probas)
    states = self.V(states)

    policies, values = self.net(states)

    value_loss = F.mse_loss(values, rewards)

    policy_loss = 0
    for search_p, pi in zip(search_probas, policies):
        # unsqueeze is not in-place, so keep the reshaped views
        search_p = search_p.unsqueeze(0)
        pi = pi.unsqueeze(-1)
        policy_loss += search_p.mm(pi)

    total_loss = value_loss + policy_loss
    total_loss.backward()

    self.net_optim.step()
def compute(self, this_state, next_state, action):
    """
    Pass input of the form state: [prop, tac, audio]
    """
    this_state = list(map(torch.cat, this_state))
    next_state = list(map(torch.cat, next_state))

    predicted_states = self.forward(this_state, action)

    loss = 0
    for i, pred_s in enumerate(predicted_states):
        loss += F.mse_loss(pred_s, next_state[i])

    loss.backward()
    self.opt.step()
    return loss
def forward(self, recon_x, x, mu, logvar, batch_size, img_size, nc):
    # recon_x: image reconstructions
    # x: images
    # mu and logvar: outputs of your encoder
    # batch_size: batch size
    # img_size: width, respectively height, of your images
    # nc: number of image channels
    MSE = F.mse_loss(recon_x, x.view(-1, img_size * img_size * nc))

    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Normalize
    KLD /= batch_size * img_size * img_size * nc

    return MSE + KLD
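# Hedged, self-contained sketch of the same reconstruction + KL computation on dummy
# tensors, outside the module above. The image size, channel count and latent size
# of 20 are illustrative assumptions; the encoder/decoder producing recon_x, mu and
# logvar are not shown.
import torch
import torch.nn.functional as F

batch_size, img_size, nc = 8, 28, 1
x = torch.rand(batch_size, nc, img_size, img_size)
recon_x = torch.rand(batch_size, img_size * img_size * nc)
mu = torch.zeros(batch_size, 20)
logvar = torch.zeros(batch_size, 20)

mse = F.mse_loss(recon_x, x.view(-1, img_size * img_size * nc))
kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
kld /= batch_size * img_size * img_size * nc
print((mse + kld).item())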
def train(self):
    """
    Trains the actor and critic networks using PPO
    """
    print("Training PPO for {} steps".format(self.total_train_steps))

    num_time_steps = 0
    num_train_iterations = 0

    while num_time_steps < self.total_train_steps:
        # Sample policy for transitions
        self.sample_policy()
        state, action, reward, next_state, done, action_log_prob = self.sample_transitions(
            return_type="torch_tensor")

        # Compute advantage and value
        returns = self.rewards_to_go(reward, done)
        value = self.critic(state, action)
        advantage = returns - value.detach()

        # Normalize advantage
        advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-10)

        # Update actor and critic networks
        for train_epoch in range(self.num_train_epochs):
            # Compute value and action prob ratio
            value = self.critic(
                state, action
            )  # TODO: See if you can just reuse value from outside the loop
            action, cur_action_log_prob = self.actor.sample_action()
            action_prob_ratio = torch.exp(cur_action_log_prob - action_log_prob)

            # Compute actor loss (PPO clipped surrogate objective)
            loss1 = action_prob_ratio * advantage
            loss2 = torch.clamp(action_prob_ratio, 1 - self.clip,
                                1 + self.clip) * advantage
            actor_loss = -torch.min(loss1, loss2).mean()

            # Compute critic loss
            critic_loss = F.mse_loss(returns, value)

            # Update actor network
            self.actor_optim.zero_grad()
            actor_loss.backward(
                retain_graph=True
            )  # retain_graph keeps the graph alive so the critic backward pass can reuse it
            self.actor_optim.step()

            # Update critic network
            self.critic_optim.zero_grad()
            critic_loss.backward()
            self.critic_optim.step()

        self.buffer.clear()
        num_time_steps += len(state)
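# Hedged standalone sketch of the clipped surrogate objective used in the actor update
# above, on dummy tensors. The batch size, clip value of 0.2 and random inputs are
# illustrative assumptions only.
import torch

clip = 0.2
advantage = torch.randn(64)
old_log_prob = torch.randn(64)
new_log_prob = old_log_prob + 0.05 * torch.randn(64)

ratio = torch.exp(new_log_prob - old_log_prob)
unclipped = ratio * advantage
clipped = torch.clamp(ratio, 1 - clip, 1 + clip) * advantage
actor_loss = -torch.min(unclipped, clipped).mean()
print(actor_loss.item())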
def compute(self, trans):
    """
    Pass input of the form state: [prop, tac, audio]
    """
    this_state = [trans.prop, trans.tac, trans.audio]
    next_state = [trans.prop_next, trans.tac_next, trans.audio_next]

    this_state = list(
        map(lambda x: x.to(device), map(torch.cat, this_state)))
    next_state = list(
        map(lambda x: x.to(device), map(torch.cat, next_state)))

    # assumes the transition also stores the action that was taken
    action = trans.action.to(device)

    predicted_states = self.forward(this_state, action)

    loss = 0
    for i, pred_s in enumerate(predicted_states):
        loss += F.mse_loss(pred_s, next_state[i])

    loss.backward()
    self.opt.step()
    return loss
def learn(self):
    if self.memory.mem_cntr < self.batch_size:
        return

    state, action, reward, new_state, done = self.memory.sample_buffer(self.batch_size)

    state = T.tensor(state, dtype=T.float).to(self.critic_1.device)
    action = T.tensor(action, dtype=T.float).to(self.critic_1.device)
    reward = T.tensor(reward, dtype=T.float).to(self.critic_1.device)
    state_ = T.tensor(new_state, dtype=T.float).to(self.critic_1.device)
    done = T.tensor(done).to(self.critic_1.device)

    # passing states and new states through value and target value networks
    # collapsing along batch dimension since we don't need 2d tensor for scalar quantities
    value = self.value(state).view(-1)
    value_ = self.target_value(state_).view(-1)
    value_[done] = 0.0  # setting terminal states to 0

    # pass current states through current policy to get action & log prob values
    actions, log_probs = self.actor.sample_normal(state, reparameterize=False)
    log_probs = log_probs.view(-1)

    # critic values for current policy state action pairs
    q1_new_policy = self.critic_1.forward(state, actions)
    q2_new_policy = self.critic_2.forward(state, actions)

    # take critic min and collapse
    critic_value = T.min(q1_new_policy, q2_new_policy)
    critic_value = critic_value.view(-1)

    self.value.optimizer.zero_grad()
    value_target = critic_value - log_probs
    value_loss = 0.5 * F.mse_loss(value, value_target)
    value_loss.backward(retain_graph=True)
    self.value.optimizer.step()

    # actor loss (using reparam trick)
    actions, log_probs = self.actor.sample_normal(state, reparameterize=True)
    log_probs = log_probs.view(-1)

    # take critic min for new policy and collapse
    # (uses the freshly sampled actions, not the replay-buffer actions)
    q1_new_policy = self.critic_1.forward(state, actions)
    q2_new_policy = self.critic_2.forward(state, actions)
    critic_value = T.min(q1_new_policy, q2_new_policy)
    critic_value = critic_value.view(-1)

    # calculating actor loss
    actor_loss = log_probs - critic_value
    actor_loss = T.mean(actor_loss)
    self.actor.optimizer.zero_grad()
    actor_loss.backward(retain_graph=True)
    self.actor.optimizer.step()

    q_hat = self.scale * reward + self.gamma * value_  # q_hat
    q1_old_policy = self.critic_1.forward(state, action).view(-1)  # old policy (from replay buffer)
    q2_old_policy = self.critic_2.forward(state, action).view(-1)
    critic_1_loss = 0.5 * F.mse_loss(q1_old_policy, q_hat)
    critic_2_loss = 0.5 * F.mse_loss(q2_old_policy, q_hat)

    self.critic_1.optimizer.zero_grad()
    self.critic_2.optimizer.zero_grad()
    critic_loss = critic_1_loss + critic_2_loss
    critic_loss.backward()
    self.critic_1.optimizer.step()
    self.critic_2.optimizer.step()

    self.update_network_parameters()
def rmse(val, rec):
    '''Implement RMSE metric'''
    return torch.sqrt(F.mse_loss(val, rec))
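# Quick usage sketch of the RMSE helper above on arbitrary dummy tensors.
import torch

val = torch.tensor([1.0, 2.0, 3.0])
rec = torch.tensor([1.5, 2.0, 2.5])
print(rmse(val, rec))  # sqrt(mean((val - rec)^2)) = sqrt(0.5 / 3) ≈ 0.4082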
def forward(self, pred, target):
    # pred, target: (batchsize, 7, 7, 30)
    containobj = target[:, :, :, 4] > 0
    noobj = target[:, :, :, 4] == 0
    containobj = containobj.unsqueeze(-1).expand_as(target)
    noobj = noobj.unsqueeze(-1).expand_as(target)

    pred_containobj = pred[containobj].view(-1, 25)  # select all grids in a batch that contain an obj
    target_containobj = target[containobj].view(-1, 25)

    box_pred = pred_containobj[:, :5]
    class_pred = pred_containobj[:, 5:]  # if obj, box and class
    box_target = target_containobj[:, :5]
    class_target = target_containobj[:, 5:]

    pred_noobj = pred[noobj].view(-1, 25)  # select all grids with no obj
    target_noobj = target[noobj].view(-1, 25)  # select all grids with no obj in target

    noobj_mask = torch.ByteTensor(pred_noobj.size())
    noobj_mask.zero_()
    noobj_mask[:, 4] = 1
    confidence_pred = pred_noobj[noobj_mask]
    confidence_target = target_noobj[noobj_mask]
    noobj_loss = F.mse_loss(confidence_pred, confidence_target, size_average=False)

    containobj_res_obj = torch.ByteTensor(box_target.size())
    containobj_res_obj.zero_()
    box_iou = torch.zeros(box_target.size())

    for i in range(0, box_target.size()[0]):
        box1_predcontain_obj = box_pred[i].view(-1, 5)
        box1_xyxy = Variable(torch.FloatTensor(box1_predcontain_obj.size()))
        box1_xyxy[:, :2] = box1_predcontain_obj[:, :2] - 0.5 * self.S * box1_predcontain_obj[:, 2:4]
        box1_xyxy[:, 2:4] = box1_predcontain_obj[:, :2] + 0.5 * self.S * box1_predcontain_obj[:, 2:4]

        box2_targetcontain_obj = box_target[i].view(-1, 5)
        box2_xyxy = Variable(torch.FloatTensor(box2_targetcontain_obj.size()))
        box2_xyxy[:, :2] = box2_targetcontain_obj[:, :2] - 0.5 * self.S * box2_targetcontain_obj[:, 2:4]
        box2_xyxy[:, 2:4] = box2_targetcontain_obj[:, :2] + 0.5 * self.S * box2_targetcontain_obj[:, 2:4]

        iou = self.compute_iou(box1_xyxy[:, :4], box2_xyxy[:, :4])
        containobj_res_obj[i] = 1
        box_iou[i, 4] = iou

    res_predcontain = box_pred[containobj_res_obj].view(-1, 5)
    res_targetcontain = box_target[containobj_res_obj].view(-1, 5)
    res_iou = box_iou[containobj_res_obj].view(-1, 5)

    coordinate_loss = (F.mse_loss(res_predcontain[:, :2],
                                  res_targetcontain[:, :2],
                                  size_average=False)
                       + F.mse_loss(torch.sqrt(res_predcontain[:, 2:4]),
                                    torch.sqrt(res_targetcontain[:, 2:4]),
                                    size_average=False))
    res_loss = F.mse_loss(res_predcontain[:, 4], res_iou[:, 4], size_average=False)
    class_loss = F.mse_loss(class_pred, class_target, size_average=False)

    return (self.l_coord * coordinate_loss + class_loss +
            self.l_noobj * noobj_loss + res_loss) / hp.batchsize
def bcq_update(batch,
               params,
               nets,
               optimizer,
               device=torch.device('cpu'),
               debug=None,
               writer=utils.DummyWriter(),
               learn=False,
               step=-1):
    """
    :param batch: batch [state, action, reward, next_state] returned by environment.
    :param params: dict of algorithm parameters.
    :param nets: dict of networks.
    :param optimizer: dict of optimizers.
    :param device: torch.device
    :param debug: dictionary where debug data about actions is saved
    :param writer: torch.SummaryWriter
    :param learn: whether to learn on this step (used for testing)
    :param step: integer step for policy update
    :return: loss dictionary

    How the parameters should look::

        params = {
            # algorithm parameters
            'gamma': 0.99,
            'soft_tau': 0.001,
            'n_generator_samples': 10,
            'perturbator_step': 30,

            # learning rates
            'perturbator_lr': 1e-5,
            'value_lr': 1e-5,
            'generator_lr': 1e-3,
        }

        nets = {
            'generator_net': models.bcqGenerator,
            'perturbator_net': models.bcqPerturbator,
            'target_perturbator_net': models.bcqPerturbator,
            'value_net1': models.Critic,
            'target_value_net1': models.Critic,
            'value_net2': models.Critic,
            'target_value_net2': models.Critic,
        }

        optimizer = {
            'generator_optimizer': some optimizer,
            'perturbator_optimizer': some optimizer,
            'value_optimizer1': some optimizer,
            'value_optimizer2': some optimizer,
        }
    """
    if debug is None:
        debug = dict()
    state, action, reward, next_state, done = data.get_base_batch(batch, device=device)
    batch_size = done.size(0)

    # --------------------------------------------------------#
    # Variational Auto-Encoder Learning
    recon, mean, std = nets['generator_net'](state, action)
    recon_loss = F.mse_loss(recon, action)
    KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
    generator_loss = recon_loss + 0.5 * KL_loss

    if not learn:
        writer.add_histogram('generator_mean', mean, step)
        writer.add_histogram('generator_std', std, step)
        debug['recon'] = recon
        writer.add_figure('reconstructed',
                          utils.pairwise_distances_fig(recon[:50]), step)

    if learn:
        optimizer['generator_optimizer'].zero_grad()
        generator_loss.backward()
        optimizer['generator_optimizer'].step()

    # --------------------------------------------------------#
    # Value Learning
    with torch.no_grad():
        # p.s. repeat_interleave was added in torch 1.1
        # if an error pops up, run 'conda update pytorch'
        state_rep = torch.repeat_interleave(next_state,
                                            params['n_generator_samples'], 0)
        sampled_action = nets['generator_net'].decode(state_rep)
        perturbed_action = nets['target_perturbator_net'](state_rep, sampled_action)
        target_Q1 = nets['target_value_net1'](state_rep, perturbed_action)
        target_Q2 = nets['target_value_net2'](state_rep, perturbed_action)
        target_value = 0.75 * torch.min(target_Q1, target_Q2)  # value soft update
        target_value += 0.25 * torch.max(target_Q1, target_Q2)
        # take the max over the actions sampled from the generator for each state
        target_value = target_value.view(batch_size, -1).max(1)[0].view(-1, 1)

        expected_value = temporal_difference(reward, done, params['gamma'], target_value)

    value = nets['value_net1'](state, action)
    value_loss = torch.pow(value - expected_value.detach(), 2).mean()

    if learn:
        optimizer['value_optimizer1'].zero_grad()
        optimizer['value_optimizer2'].zero_grad()
        value_loss.backward()
        optimizer['value_optimizer1'].step()
        optimizer['value_optimizer2'].step()
    else:
        writer.add_histogram('value', value, step)
        writer.add_histogram('target_value', target_value, step)
        writer.add_histogram('expected_value', expected_value, step)
        writer.close()

    # --------------------------------------------------------#
    # Perturbator learning
    sampled_actions = nets['generator_net'].decode(state)
    perturbed_actions = nets['perturbator_net'](state, sampled_actions)
    perturbator_loss = -nets['value_net1'](state, perturbed_actions)
    if not learn:
        writer.add_histogram('perturbator_loss', perturbator_loss, step)
    perturbator_loss = perturbator_loss.mean()

    if learn:
        if step % params['perturbator_step']:
            optimizer['perturbator_optimizer'].zero_grad()
            perturbator_loss.backward()
            torch.nn.utils.clip_grad_norm_(nets['perturbator_net'].parameters(), -1, 1)
            optimizer['perturbator_optimizer'].step()

        soft_update(nets['value_net1'], nets['target_value_net1'],
                    soft_tau=params['soft_tau'])
        soft_update(nets['value_net2'], nets['target_value_net2'],
                    soft_tau=params['soft_tau'])
        soft_update(nets['perturbator_net'], nets['target_perturbator_net'],
                    soft_tau=params['soft_tau'])
    else:
        debug['sampled_actions'] = sampled_actions
        debug['perturbed_actions'] = perturbed_actions
        writer.add_figure('sampled_actions',
                          utils.pairwise_distances_fig(sampled_actions[:50]), step)
        writer.add_figure('perturbed_actions',
                          utils.pairwise_distances_fig(perturbed_actions[:50]), step)

    # --------------------------------------------------------#
    losses = {
        'value': value_loss.item(),
        'perturbator': perturbator_loss.item(),
        'generator': generator_loss.item(),
        'step': step
    }

    utils.write_losses(writer, losses, kind='train' if learn else 'test')

    return losses
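# soft_update is referenced above but not defined in this snippet. Below is a minimal
# Polyak-averaging helper of the kind it presumably is; the exact signature is an
# assumption for illustration, the project's own helper may differ.
def soft_update(net, target_net, soft_tau=1e-2):
    # target <- tau * source + (1 - tau) * target, parameter by parameter
    for target_param, param in zip(target_net.parameters(), net.parameters()):
        target_param.data.copy_(soft_tau * param.data +
                                (1.0 - soft_tau) * target_param.data)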
def test_model(model, test_loader, epoch, now, batch_idx, test_len=100):
    if not os.path.exists(EVAL_DIR + now):
        os.makedirs(EVAL_DIR + now)

    model.eval()
    model.train_stat = False
    test_losses = []
    diffs_avg = []

    with torch.no_grad():
        for i, (name, x, (y_l, y_ab)) in enumerate(test_loader):
            x = x.to(device)
            y_l = y_l.to(device)
            y_ab = y_ab.to(device)
            inputs = x.cpu()
            # print('Inputs shape ', inputs.shape)

            out_ab = model(x)
            loss = F.mse_loss(out_ab, y_ab)
            print('Test batch %d Loss %.4f' % (i, loss.item()))
            test_losses.append(loss.item())
            # print('got output')
            # print('outputs shape', output.shape)

            out_ab = out_ab.permute(0, 2, 3, 1)
            output = out_ab.cpu().numpy()

            j = random.randint(0, len(output) - 1)
            a_channel = output[j][:, :, 0]
            b_channel = output[j][:, :, 1]
            image = inputs[j].squeeze(0).numpy()
            # print(image.shape, a_channel.shape, b_channel.shape)

            actual_color_image = plt.imread(args.data_path + COLOR_DIR + name[j])
            true_ab = cv2.cvtColor(actual_color_image, cv2.COLOR_RGB2LAB)[:, :, 1:]

            np_image = np.dstack((image, a_channel, b_channel))
            np_rgb = cv2.cvtColor(np_image, cv2.COLOR_LAB2RGB)

            diffs = color_diff(np.dstack((a_channel, b_channel)), true_ab, image)
            # import ipdb; ipdb.set_trace()

            file_name = EVAL_DIR + now + '/' + 'cimg_' + str(epoch) + '_' + \
                str(batch_idx) + '_' + str(j) + '_' + name[j]
            # exit()

            grid = make_grid(actual_color_image, np_rgb, image)
            summary_writer.add_image(
                "test image/" + 'cimg_' + str(epoch) + '_' + str(batch_idx) +
                '_' + str(j) + '_' + name[j], grid)

            imsave(file_name, np_rgb)
            cv2.imwrite(file_name.split('.png')[0] + '_mri.png', image)
            diffs_avg.append(np.average(diffs))

            with open(EVAL_DIR + now + '/order.txt', 'a') as f:
                val = "%d, %s\n" % (epoch, 'cimg_' + str(epoch) + '_' + name[j])
                f.writelines(val)

            if i % 100 == 0 and i != 0:
                break
            if args.test_mode:
                break

    summary_writer.add_scalar('lab difference', np.average(diffs_avg))
    return test_losses