def forward(self, img_inputs, captions, target_ingrs, sample=False, keep_cnn_gradients=False):
    """Compute training losses for ingredient prediction and recipe generation.

    Args:
        img_inputs: batch of input images fed to the image encoder.
        captions: recipe token ids, first column is the start token
            (shifted off to build the targets).
        target_ingrs: ground-truth ingredient ids, 0 marks eos and
            ``self.pad_value`` marks padding.
        sample: when True, bypass loss computation and greedily sample.
        keep_cnn_gradients: whether the CNN encoder keeps gradients.

    Returns:
        dict of per-sample loss tensors; keys depend on ``self.recipe_only``
        / ``self.ingrs_only`` flags ('ingr_loss', 'card_penalty', 'eos_loss',
        'iou', 'recipe_loss').
    """
    if sample:
        return self.sample(img_inputs, greedy=True)
    # recipe targets: drop the start token and flatten for the criterion
    targets = captions[:, 1:]
    targets = targets.contiguous().view(-1)
    img_features = self.image_encoder(img_inputs, keep_cnn_gradients)
    losses = {}
    target_one_hot = label2onehot(target_ingrs, self.pad_value)
    target_one_hot_smooth = label2onehot(target_ingrs, self.pad_value)
    # ingredient prediction
    if not self.recipe_only:
        # label smoothing: soften the 0/1 one-hot target
        target_one_hot_smooth[target_one_hot_smooth == 1] = (1 - self.label_smoothing)
        target_one_hot_smooth[target_one_hot_smooth == 0] = self.label_smoothing / target_one_hot_smooth.size(-1)
        # decode ingredients with transformer
        # autoregressive mode for ingredient decoder
        ingr_ids, ingr_logits = self.ingredient_decoder.sample(None, None, greedy=True,
                                                               temperature=1.0,
                                                               img_features=img_features,
                                                               first_token_value=0,
                                                               replacement=False)
        ingr_logits = torch.nn.functional.softmax(ingr_logits, dim=-1)
        # find idxs for eos ingredient
        # eos probability is the one assigned to the first position of the softmax
        eos = ingr_logits[:, :, 0]
        # timesteps that are eos or padding (^ acts as OR here, assuming
        # pad_value != 0 so the two masks are disjoint)
        target_eos = ((target_ingrs == 0) ^ (target_ingrs == self.pad_value))
        eos_pos = (target_ingrs == 0)            # exact eos position(s)
        eos_head = ((target_ingrs != self.pad_value) & (target_ingrs != 0))  # real ingredients before eos
        # select transformer steps to pool from
        mask_perminv = mask_from_eos(target_ingrs, eos_value=0, mult_before=False)
        # permutation-invariant pooling: max over decoding steps up to eos
        ingr_probs = ingr_logits * mask_perminv.float().unsqueeze(-1)
        ingr_probs, _ = torch.max(ingr_probs, dim=1)
        # ignore predicted ingredients after eos in ground truth
        ingr_ids[mask_perminv == 0] = self.pad_value
        ingr_loss = self.crit_ingr(ingr_probs, target_one_hot_smooth)
        ingr_loss = torch.mean(ingr_loss, dim=-1)
        losses['ingr_loss'] = ingr_loss
        # cardinality penalty: encourages predicting the right number of
        # ingredients (mass on true set close to its size, little mass elsewhere)
        losses['card_penalty'] = torch.abs((ingr_probs*target_one_hot).sum(1) - target_one_hot.sum(1)) + \
                                 torch.abs((ingr_probs*(1-target_one_hot)).sum(1))
        eos_loss = self.crit_eos(eos, target_eos.float())
        mult = 1 / 2
        # eos loss is only computed for timesteps <= t_eos and equally penalizes 0s and 1s
        losses['eos_loss'] = mult*(eos_loss * eos_pos.float()).sum(1) / (eos_pos.float().sum(1) + 1e-6) + \
                             mult*(eos_loss * eos_head.float()).sum(1) / (eos_head.float().sum(1) + 1e-6)
        # iou
        pred_one_hot = label2onehot(ingr_ids, self.pad_value)
        # iou sample during training is computed using the true eos position
        losses['iou'] = softIoU(pred_one_hot, target_one_hot)
    if self.ingrs_only:
        return losses
    # encode ingredients (teacher forcing: ground-truth ingredients feed the recipe decoder)
    target_ingr_feats = self.ingredient_encoder(target_ingrs)
    target_ingr_mask = mask_from_eos(target_ingrs, eos_value=0, mult_before=False)
    target_ingr_mask = target_ingr_mask.float().unsqueeze(1)
    outputs, ids = self.recipe_decoder(target_ingr_feats, target_ingr_mask,
                                       captions, img_features)
    # drop the last step and flatten to (batch*steps, vocab) for the criterion
    outputs = outputs[:, :-1, :].contiguous()
    outputs = outputs.view(outputs.size(0) * outputs.size(1), -1)
    loss = self.crit(outputs, targets)
    losses['recipe_loss'] = loss
    return losses
def main(args):
    """Evaluate a trained model: either report recipe perplexity or sample
    ingredients/recipes, compute metrics, and dump generations to a pickle.

    Fix vs. original: the per-image guard before appending a sampled recipe
    tested membership in an unrelated, never-populated ``captions`` dict, so
    the results list was re-created on every generation; it now tests
    ``results_dict['recipes']`` so multiple generations accumulate correctly.
    """
    where_to_save = os.path.join(args.save_dir, args.project_name, args.model_name)
    checkpoints_dir = os.path.join(where_to_save, 'checkpoints')
    logs_dir = os.path.join(where_to_save, 'logs')
    # optionally redirect stdout/stderr to log files on disk
    if not args.log_term:
        print("Eval logs will be saved to:", os.path.join(logs_dir, 'eval.log'))
        sys.stdout = open(os.path.join(logs_dir, 'eval.log'), 'w')
        sys.stderr = open(os.path.join(logs_dir, 'eval.err'), 'w')
    # these flags come from the command line and must survive a (currently
    # disabled) reload of the training-time args pickle
    vars_to_replace = [
        'greedy', 'recipe_only', 'ingrs_only', 'temperature', 'batch_size',
        'maxseqlen', 'get_perplexity', 'use_true_ingrs', 'eval_split',
        'save_dir', 'aux_data_dir', 'recipe1m_dir', 'project_name',
        'use_lmdb', 'beam'
    ]
    store_dict = {}
    for var in vars_to_replace:
        store_dict[var] = getattr(args, var)
    #args = pickle.load(open(os.path.join(checkpoints_dir, 'args.pkl'), 'rb'))
    for var in vars_to_replace:
        setattr(args, var, store_dict[var])
    print(args)

    # Image preprocessing: deterministic resize + center crop + ImageNet norm
    transforms_list = [
        transforms.Resize((args.crop_size)),
        transforms.CenterCrop(args.crop_size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    transform = transforms.Compose(transforms_list)

    # data loader
    data_dir = args.recipe1m_dir
    data_loader, dataset = get_loader(data_dir, args.aux_data_dir, args.eval_split,
                                      args.maxseqlen, args.maxnuminstrs,
                                      args.maxnumlabels, args.maxnumims, transform,
                                      args.batch_size, shuffle=False,
                                      num_workers=args.num_workers, drop_last=False,
                                      max_num_samples=-1, use_lmdb=args.use_lmdb,
                                      suff=args.suff)
    ingr_vocab_size = dataset.get_ingrs_vocab_size()
    instrs_vocab_size = dataset.get_instrs_vocab_size()
    args.numgens = 1  # number of generations per image

    # Build the model
    model = get_model(args, ingr_vocab_size, instrs_vocab_size)
    model_path = os.path.join(args.save_dir, args.project_name, args.model_name,
                              'checkpoints', 'modelbest.ckpt')
    # overwrite flags for inference
    model.recipe_only = args.recipe_only
    model.ingrs_only = args.ingrs_only
    # Load the trained model parameters
    model.load_state_dict(torch.load(model_path, map_location=map_loc))
    model.eval()
    model = model.to(device)

    results_dict = {'recipes': {}, 'ingrs': {}, 'ingr_iou': {}}
    iou = []
    error_types = {'tp_i': 0, 'fp_i': 0, 'fn_i': 0, 'tn_i': 0,
                   'tp_all': 0, 'fp_all': 0, 'fn_all': 0}
    perplexity_list = []
    # n_rep counts samples whose repetition score falls below threshold th
    n_rep, th = 0, 0.3

    for i, (img_inputs, true_caps_batch, ingr_gt, imgid, impath) in tqdm(enumerate(data_loader)):
        ingr_gt = ingr_gt.to(device)
        true_caps_batch = true_caps_batch.to(device)
        true_caps_shift = true_caps_batch.clone()[:, 1:].contiguous()
        img_inputs = img_inputs.to(device)
        true_ingrs = ingr_gt if args.use_true_ingrs else None
        for gens in range(args.numgens):
            with torch.no_grad():
                if args.get_perplexity:
                    # per-sample perplexity on ground-truth captions
                    losses = model(img_inputs, true_caps_batch, ingr_gt,
                                   keep_cnn_gradients=False)
                    recipe_loss = losses['recipe_loss']
                    recipe_loss = recipe_loss.view(true_caps_shift.size())
                    # instrs_vocab_size - 1 is the pad token; exclude it
                    non_pad_mask = true_caps_shift.ne(instrs_vocab_size - 1).float()
                    recipe_loss = torch.sum(recipe_loss * non_pad_mask, dim=-1) / torch.sum(non_pad_mask, dim=-1)
                    perplexity = torch.exp(recipe_loss)
                    perplexity = perplexity.detach().cpu().numpy().tolist()
                    perplexity_list.extend(perplexity)
                else:
                    outputs = model.sample(img_inputs, args.greedy, args.temperature,
                                           args.beam, true_ingrs)
                    if not args.recipe_only:
                        fake_ingrs = outputs['ingr_ids']
                        pred_one_hot = label2onehot(fake_ingrs, ingr_vocab_size - 1)
                        target_one_hot = label2onehot(ingr_gt, ingr_vocab_size - 1)
                        iou_item = torch.mean(softIoU(pred_one_hot, target_one_hot)).item()
                        iou.append(iou_item)
                        update_error_types(error_types, pred_one_hot, target_one_hot)
                        fake_ingrs = fake_ingrs.detach().cpu().numpy()
                        for ingr_idx, fake_ingr in enumerate(fake_ingrs):
                            iou_item = softIoU(pred_one_hot[ingr_idx].unsqueeze(0),
                                               target_one_hot[ingr_idx].unsqueeze(0)).item()
                            results_dict['ingrs'][imgid[ingr_idx]] = []
                            results_dict['ingrs'][imgid[ingr_idx]].append(fake_ingr)
                            results_dict['ingr_iou'][imgid[ingr_idx]] = iou_item
                    if not args.ingrs_only:
                        sampled_ids_batch = outputs['recipe_ids']
                        sampled_ids_batch = sampled_ids_batch.cpu().detach().numpy()
                        for j, sampled_ids in enumerate(sampled_ids_batch):
                            score = compute_score(sampled_ids)
                            if score < th:
                                n_rep += 1
                            # BUGFIX: initialize the per-image list only the
                            # first time we see this image id (the original
                            # checked a never-populated dict, re-creating the
                            # list and dropping prior generations)
                            if imgid[j] not in results_dict['recipes']:
                                results_dict['recipes'][imgid[j]] = []
                            results_dict['recipes'][imgid[j]].append(sampled_ids)

    if args.get_perplexity:
        print(len(perplexity_list))
        print(np.mean(perplexity_list))
    else:
        if not args.recipe_only:
            ret_metrics = {'accuracy': [], 'f1': [], 'jaccard': [], 'f1_ingredients': []}
            compute_metrics(ret_metrics, error_types,
                            ['accuracy', 'f1', 'jaccard', 'f1_ingredients'],
                            eps=1e-10, weights=None)
            for k, v in ret_metrics.items():
                print(k, np.mean(v))
        # suffix encodes the decoding strategy used for this run
        if args.greedy:
            suff = 'greedy'
        elif args.beam != -1:
            suff = 'beam_' + str(args.beam)
        else:
            suff = 'temp_' + str(args.temperature)
        results_file = os.path.join(args.save_dir, args.project_name, args.model_name,
                                    'checkpoints',
                                    args.eval_split + '_' + suff + '_gencaps.pkl')
        print(results_file)
        pickle.dump(results_dict, open(results_file, 'wb'))
        print("Number of samples with excessive repetitions:", n_rep)
def main(args):
    """Train the model: build loaders, optimizer (optionally fine-tuning the
    CNN), run the train/val loop per epoch, log losses, and early-stop on the
    chosen metric.

    Fix vs. original: ``loader.next()`` is the Python-2-era iterator call and
    raises AttributeError on Python 3 / modern PyTorch DataLoader iterators;
    replaced with the builtin ``next(loader)``.
    """
    # Create model directory & other aux folders for logging
    where_to_save = os.path.join(args.save_dir, args.project_name, args.model_name)
    checkpoints_dir = os.path.join(where_to_save, 'checkpoints')
    logs_dir = os.path.join(where_to_save, 'logs')
    tb_logs = os.path.join(args.save_dir, args.project_name, 'tb_logs', args.model_name)
    make_dir(where_to_save)
    make_dir(logs_dir)
    make_dir(checkpoints_dir)
    make_dir(tb_logs)
    if args.tensorboard:
        logger = Visualizer(tb_logs, name='visual_results')

    # check if we want to resume from last checkpoint of current model
    if args.resume:
        args = pickle.load(open(os.path.join(checkpoints_dir, 'args.pkl'), 'rb'))
        args.resume = True

    # logs to disk
    if not args.log_term:
        print("Training logs will be saved to:", os.path.join(logs_dir, 'train.log'))
        sys.stdout = open(os.path.join(logs_dir, 'train.log'), 'w')
        sys.stderr = open(os.path.join(logs_dir, 'train.err'), 'w')
    print(args)
    pickle.dump(args, open(os.path.join(checkpoints_dir, 'args.pkl'), 'wb'))

    # patience init
    curr_pat = 0

    # Build data loaders: augmentation for train, deterministic crop for val
    data_loaders = {}
    datasets = {}
    data_dir = args.recipe1m_dir
    for split in ['train', 'val']:
        transforms_list = [transforms.Resize((args.image_size))]
        if split == 'train':
            # Image preprocessing, normalization for the pretrained resnet
            transforms_list.append(transforms.RandomHorizontalFlip())
            transforms_list.append(transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)))
            transforms_list.append(transforms.RandomCrop(args.crop_size))
        else:
            transforms_list.append(transforms.CenterCrop(args.crop_size))
        transforms_list.append(transforms.ToTensor())
        transforms_list.append(transforms.Normalize((0.485, 0.456, 0.406),
                                                    (0.229, 0.224, 0.225)))
        transform = transforms.Compose(transforms_list)
        max_num_samples = max(args.max_eval, args.batch_size) if split == 'val' else -1
        data_loaders[split], datasets[split] = get_loader(
            data_dir, args.aux_data_dir, split, args.maxseqlen, args.maxnuminstrs,
            args.maxnumlabels, args.maxnumims, transform, args.batch_size,
            shuffle=split == 'train', num_workers=args.num_workers, drop_last=True,
            max_num_samples=max_num_samples, use_lmdb=args.use_lmdb, suff=args.suff)

    ingr_vocab_size = datasets[split].get_ingrs_vocab_size()
    instrs_vocab_size = datasets[split].get_instrs_vocab_size()

    # Build the model
    model = get_model(args, ingr_vocab_size, instrs_vocab_size)
    keep_cnn_gradients = False
    decay_factor = 1.0

    # add model parameters depending on which sub-task is trained
    if args.ingrs_only:
        params = list(model.ingredient_decoder.parameters()) + list(
            model.ingredient_encoder.parameters())
    elif args.recipe_only:
        params = list(model.recipe_decoder.parameters()) + list(
            model.ingredient_encoder.parameters())
    else:
        params = list(model.recipe_decoder.parameters()) \
            + list(model.ingredient_decoder.parameters()) \
            + list(model.ingredient_encoder.parameters())

    # only train the linear layer in the encoder if we are not transfering from another model
    if args.transfer_from == '':
        params += list(model.image_encoder.linear.parameters())
    params_cnn = list(model.image_encoder.resnet.parameters())

    print("CNN params:", sum(p.numel() for p in params_cnn if p.requires_grad))
    print("decoder params:", sum(p.numel() for p in params if p.requires_grad))

    # start optimizing cnn from the beginning
    if params_cnn is not None and args.finetune_after == 0:
        optimizer = torch.optim.Adam(
            [{'params': params},
             {'params': params_cnn,
              'lr': args.learning_rate * args.scale_learning_rate_cnn}],
            lr=args.learning_rate, weight_decay=args.weight_decay)
        keep_cnn_gradients = True
        print("Fine tuning resnet")
    else:
        # NOTE(review): this branch omits weight_decay, unlike the branch
        # above — confirm whether that asymmetry is intentional
        optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    if args.resume:
        model_path = os.path.join(args.save_dir, args.project_name, args.model_name,
                                  'checkpoints', 'model.ckpt')
        optim_path = os.path.join(args.save_dir, args.project_name, args.model_name,
                                  'checkpoints', 'optim.ckpt')
        optimizer.load_state_dict(torch.load(optim_path, map_location=map_loc))
        # move restored optimizer state tensors onto the active device
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
        model.load_state_dict(torch.load(model_path, map_location=map_loc))

    if args.transfer_from != '':
        # loads CNN encoder from transfer_from model
        model_path = os.path.join(args.save_dir, args.project_name, args.transfer_from,
                                  'checkpoints', 'modelbest.ckpt')
        pretrained_dict = torch.load(model_path, map_location=map_loc)
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if 'encoder' in k}
        model.load_state_dict(pretrained_dict, strict=False)
        args, model = merge_models(args, model, ingr_vocab_size, instrs_vocab_size)

    if device != 'cpu' and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model = model.to(device)
    cudnn.benchmark = True

    if not hasattr(args, 'current_epoch'):
        args.current_epoch = 0

    es_best = 10000 if args.es_metric == 'loss' else 0
    # Train the model
    start = args.current_epoch
    for epoch in range(start, args.num_epochs):
        # save current epoch for resuming
        if args.tensorboard:
            logger.reset()
        args.current_epoch = epoch

        # increase / decrease values for moving params
        if args.decay_lr:
            frac = epoch // args.lr_decay_every
            decay_factor = args.lr_decay_rate**frac
            new_lr = args.learning_rate * decay_factor
            print('Epoch %d. lr: %.5f' % (epoch, new_lr))
            # set_lr takes the multiplicative factor, not the absolute lr
            set_lr(optimizer, decay_factor)

        if args.finetune_after != -1 and args.finetune_after < epoch \
                and not keep_cnn_gradients and params_cnn is not None:
            print("Starting to fine tune CNN")
            # start with learning rates as they were (if decayed during training)
            optimizer = torch.optim.Adam(
                [{'params': params},
                 {'params': params_cnn,
                  'lr': decay_factor * args.learning_rate * args.scale_learning_rate_cnn}],
                lr=decay_factor * args.learning_rate)
            keep_cnn_gradients = True

        for split in ['train', 'val']:
            if split == 'train':
                model.train()
            else:
                model.eval()
            total_step = len(data_loaders[split])
            loader = iter(data_loaders[split])
            total_loss_dict = {'recipe_loss': [], 'ingr_loss': [], 'eos_loss': [],
                               'loss': [], 'iou': [], 'perplexity': [],
                               'iou_sample': [], 'f1': [], 'card_penalty': []}
            error_types = {'tp_i': 0, 'fp_i': 0, 'fn_i': 0, 'tn_i': 0,
                           'tp_all': 0, 'fp_all': 0, 'fn_all': 0}
            torch.cuda.synchronize()
            start = time.time()

            for i in range(total_step):
                # BUGFIX: builtin next() instead of the py2-only .next()
                img_inputs, captions, ingr_gt, img_ids, paths = next(loader)
                ingr_gt = ingr_gt.to(device)
                img_inputs = img_inputs.to(device)
                captions = captions.to(device)
                true_caps_batch = captions.clone()[:, 1:].contiguous()
                loss_dict = {}

                if split == 'val':
                    with torch.no_grad():
                        losses = model(img_inputs, captions, ingr_gt)
                        if not args.recipe_only:
                            # also sample greedily to track a sampled IoU metric
                            outputs = model(img_inputs, captions, ingr_gt, sample=True)
                            ingr_ids_greedy = outputs['ingr_ids']
                            mask = mask_from_eos(ingr_ids_greedy, eos_value=0, mult_before=False)
                            ingr_ids_greedy[mask == 0] = ingr_vocab_size - 1
                            pred_one_hot = label2onehot(ingr_ids_greedy, ingr_vocab_size - 1)
                            target_one_hot = label2onehot(ingr_gt, ingr_vocab_size - 1)
                            iou_sample = softIoU(pred_one_hot, target_one_hot)
                            iou_sample = iou_sample.sum() / (torch.nonzero(iou_sample.data).size(0) + 1e-6)
                            loss_dict['iou_sample'] = iou_sample.item()
                            update_error_types(error_types, pred_one_hot, target_one_hot)
                            del outputs, pred_one_hot, target_one_hot, iou_sample
                else:
                    losses = model(img_inputs, captions, ingr_gt,
                                   keep_cnn_gradients=keep_cnn_gradients)

                if not args.ingrs_only:
                    recipe_loss = losses['recipe_loss']
                    recipe_loss = recipe_loss.view(true_caps_batch.size())
                    # instrs_vocab_size - 1 is the pad token; mask it out
                    non_pad_mask = true_caps_batch.ne(instrs_vocab_size - 1).float()
                    recipe_loss = torch.sum(recipe_loss * non_pad_mask, dim=-1) / torch.sum(non_pad_mask, dim=-1)
                    perplexity = torch.exp(recipe_loss)
                    recipe_loss = recipe_loss.mean()
                    perplexity = perplexity.mean()
                    loss_dict['recipe_loss'] = recipe_loss.item()
                    loss_dict['perplexity'] = perplexity.item()
                else:
                    recipe_loss = 0

                if not args.recipe_only:
                    ingr_loss = losses['ingr_loss']
                    ingr_loss = ingr_loss.mean()
                    loss_dict['ingr_loss'] = ingr_loss.item()
                    eos_loss = losses['eos_loss']
                    eos_loss = eos_loss.mean()
                    loss_dict['eos_loss'] = eos_loss.item()
                    iou_seq = losses['iou']
                    iou_seq = iou_seq.mean()
                    loss_dict['iou'] = iou_seq.item()
                    card_penalty = losses['card_penalty'].mean()
                    loss_dict['card_penalty'] = card_penalty.item()
                else:
                    ingr_loss, eos_loss, card_penalty = 0, 0, 0

                # weighted combination of all task losses
                loss = args.loss_weight[0] * recipe_loss + args.loss_weight[1] * ingr_loss \
                    + args.loss_weight[2] * eos_loss + args.loss_weight[3] * card_penalty
                loss_dict['loss'] = loss.item()

                for key in loss_dict.keys():
                    total_loss_dict[key].append(loss_dict[key])

                if split == 'train':
                    model.zero_grad()
                    loss.backward()
                    optimizer.step()

                # Print log info
                if args.log_step != -1 and i % args.log_step == 0:
                    elapsed_time = time.time() - start
                    lossesstr = ""
                    for k in total_loss_dict.keys():
                        if len(total_loss_dict[k]) == 0:
                            continue
                        this_one = "%s: %.4f" % (k, np.mean(total_loss_dict[k][-args.log_step:]))
                        lossesstr += this_one + ', '
                    # this only displays nll loss on captions, the rest of
                    # losses will be in tensorboard logs
                    strtoprint = 'Split: %s, Epoch [%d/%d], Step [%d/%d], Losses: %sTime: %.4f' % (
                        split, epoch, args.num_epochs, i, total_step, lossesstr, elapsed_time)
                    print(strtoprint)
                    if args.tensorboard:
                        logger.scalar_summary(
                            mode=split + '_iter', epoch=total_step * epoch + i,
                            **{k: np.mean(v[-args.log_step:])
                               for k, v in total_loss_dict.items() if v})
                    torch.cuda.synchronize()
                    start = time.time()
                del loss, losses, captions, img_inputs

            if split == 'val' and not args.recipe_only:
                ret_metrics = {'accuracy': [], 'f1': [], 'jaccard': [],
                               'f1_ingredients': [], 'dice': []}
                compute_metrics(ret_metrics, error_types,
                                ['accuracy', 'f1', 'jaccard', 'f1_ingredients', 'dice'],
                                eps=1e-10, weights=None)
                total_loss_dict['f1'] = ret_metrics['f1']
            if args.tensorboard:
                # 1. Log scalar values (scalar summary)
                logger.scalar_summary(
                    mode=split, epoch=epoch,
                    **{k: np.mean(v) for k, v in total_loss_dict.items() if v})

        # Save the model's best checkpoint if performance was improved
        es_value = np.mean(total_loss_dict[args.es_metric])
        # save current model as well
        save_model(model, optimizer, checkpoints_dir, suff='')
        if (args.es_metric == 'loss' and es_value < es_best) \
                or (args.es_metric == 'iou_sample' and es_value > es_best):
            es_best = es_value
            save_model(model, optimizer, checkpoints_dir, suff='best')
            pickle.dump(args, open(os.path.join(checkpoints_dir, 'args.pkl'), 'wb'))
            curr_pat = 0
            print('Saved checkpoint.')
        else:
            curr_pat += 1

        # early stopping on patience exhaustion
        if curr_pat > args.patience:
            break

    if args.tensorboard:
        logger.close()
def forward(self, img_inputs, target_ingrs, target_action, sample=False, keep_cnn_gradients=False):
    """Compute training losses for the ingredient and action set decoders.

    Args:
        img_inputs: batch of input images fed to the image encoder.
        target_ingrs: ground-truth ingredient ids; 0 marks eos,
            ``self.pad_value_ingrs`` marks padding.
        target_action: ground-truth action ids; 0 marks eos,
            ``self.pad_value_action`` marks padding.
        sample: when True, bypass loss computation and greedily sample.
        keep_cnn_gradients: whether the CNN encoder keeps gradients.

    Returns:
        dict with 'ingr_eos_loss', 'ingr_loss', 'ingr_iou',
        'action_eos_loss', 'action_loss', 'action_iou'.
    """
    if sample:
        return self.sample(img_inputs, greedy=True)
    img_features = self.image_encoder(img_inputs, keep_cnn_gradients)
    losses = {}
    #######################################################################
    # build one-hot targets for the ingredient set
    target_one_hot_ingrs = label2onehot(target_ingrs, self.pad_value_ingrs)
    target_one_hot_smooth_ingrs = label2onehot(target_ingrs, self.pad_value_ingrs)
    # label smoothing
    # NOTE(review): the smoothed target below is computed but never used —
    # the ingredient loss is taken against the unsmoothed one-hot target.
    # Confirm whether smoothing was meant to be applied.
    target_one_hot_smooth_ingrs[target_one_hot_smooth_ingrs == 1] = (1 - self.label_smoothing)
    target_one_hot_smooth_ingrs[target_one_hot_smooth_ingrs == 0] = self.label_smoothing / target_one_hot_smooth_ingrs.size(-1)
    # autoregressive greedy decode of the ingredient set
    ingr_ids, ingr_logits = self.ingredient_decoder.sample(None, None, greedy=True,
                                                           temperature=1.0,
                                                           img_features=img_features,
                                                           first_token_value=0,
                                                           replacement=False)
    ingr_logits = torch.nn.functional.softmax(ingr_logits, dim=-1)
    ############################
    # ingredient eos loss: eos probability is position 0 of the softmax
    ingr_eos = ingr_logits[:, :, 0]
    # timesteps that are eos or padding (^ acts as OR here, assuming
    # pad_value_ingrs != 0 so the two masks are disjoint)
    target_ingr_eos = ((target_ingrs == 0) ^ (target_ingrs == self.pad_value_ingrs))
    target_ingr_eos = target_ingr_eos.float()
    ingr_eos_loss = self.crit(ingr_eos, target_ingr_eos)
    ingr_eos_loss = torch.mean(ingr_eos_loss, dim=-1)
    losses['ingr_eos_loss'] = ingr_eos_loss
    #########################################
    # ingredient loss: permutation-invariant max-pool over decoding steps up
    # to the ground-truth eos
    mask_perminv_ingrs = mask_from_eos(target_ingrs, eos_value=0, mult_before=False)
    ingr_probs = ingr_logits * mask_perminv_ingrs.float().unsqueeze(-1)
    ingr_probs, _ = torch.max(ingr_probs, dim=1)
    # ignore predicted ingredients after eos in ground truth
    ingr_ids[mask_perminv_ingrs == 0] = self.pad_value_ingrs
    ingr_loss = self.crit(ingr_probs, target_one_hot_ingrs)
    ingr_loss = torch.mean(ingr_loss, dim=-1)
    losses['ingr_loss'] = ingr_loss
    # iou
    pred_one_hot_ingrs = label2onehot(ingr_ids, self.pad_value_ingrs)
    losses['ingr_iou'] = softIoU(pred_one_hot_ingrs, target_one_hot_ingrs)
    ################################################################################
    # build one-hot targets for the action set
    target_one_hot_action = label2onehot(target_action, self.pad_value_action)
    target_one_hot_smooth_action = label2onehot(target_action, self.pad_value_action)
    # NOTE(review): smoothed action target is likewise unused below — verify.
    target_one_hot_smooth_action[target_one_hot_smooth_action == 1] = (1 - self.label_smoothing)
    target_one_hot_smooth_action[target_one_hot_smooth_action == 0] = self.label_smoothing / target_one_hot_smooth_action.size(-1)
    # autoregressive greedy decode of the action set
    action_ids, action_logits = self.action_decoder.sample(None, None, greedy=True,
                                                           temperature=1.0,
                                                           img_features=img_features,
                                                           first_token_value=0,
                                                           replacement=False)
    action_logits = torch.nn.functional.softmax(action_logits, dim=-1)
    ############################
    # action eos loss (same construction as the ingredient eos loss)
    action_eos = action_logits[:, :, 0]
    target_action_eos = ((target_action == 0) ^ (target_action == self.pad_value_action))
    target_action_eos = target_action_eos.float()
    action_eos_loss = self.crit(action_eos, target_action_eos)
    action_eos_loss = torch.mean(action_eos_loss, dim=-1)
    losses['action_eos_loss'] = action_eos_loss
    ########################################
    # action loss: permutation-invariant max-pool, mirroring the ingredient loss
    mask_perminv_action = mask_from_eos(target_action, eos_value=0, mult_before=False)
    action_probs = action_logits * mask_perminv_action.float().unsqueeze(-1)
    action_probs, _ = torch.max(action_probs, dim=1)
    action_ids[mask_perminv_action == 0] = self.pad_value_action
    action_loss = self.crit(action_probs, target_one_hot_action)
    action_loss = torch.mean(action_loss, dim=-1)
    losses['action_loss'] = action_loss
    # iou
    pred_one_hot_action = label2onehot(action_ids, self.pad_value_action)
    losses['action_iou'] = softIoU(pred_one_hot_action, target_one_hot_action)
    return losses