Example #1
    def forward(self,
                img_inputs,
                captions,
                target_ingrs,
                sample=False,
                keep_cnn_gradients=False):

        if sample:
            return self.sample(img_inputs, greedy=True)

        targets = captions[:, 1:]
        targets = targets.contiguous().view(-1)

        img_features = self.image_encoder(img_inputs, keep_cnn_gradients)

        losses = {}
        target_one_hot = label2onehot(target_ingrs, self.pad_value)
        target_one_hot_smooth = label2onehot(target_ingrs, self.pad_value)

        # ingredient prediction
        if not self.recipe_only:
            # label smoothing: keep (1 - eps) on the positives and spread
            # eps uniformly over the remaining classes
            target_one_hot_smooth[target_one_hot_smooth == 1] = 1 - self.label_smoothing
            target_one_hot_smooth[target_one_hot_smooth == 0] = (
                self.label_smoothing / target_one_hot_smooth.size(-1))

            # decode ingredients with transformer
            # autoregressive mode for ingredient decoder
            ingr_ids, ingr_logits = self.ingredient_decoder.sample(
                None,
                None,
                greedy=True,
                temperature=1.0,
                img_features=img_features,
                first_token_value=0,
                replacement=False)

            # after the softmax these are probabilities rather than logits
            ingr_logits = torch.nn.functional.softmax(ingr_logits, dim=-1)

            # find idxs for eos ingredient
            # eos probability is the one assigned to the first position of the softmax
            eos = ingr_logits[:, :, 0]
            # eos targets: positions holding the eos token (0) or padding
            # (the two masks are disjoint, so ^ acts as an OR here)
            target_eos = ((target_ingrs == 0) ^
                          (target_ingrs == self.pad_value))

            eos_pos = (target_ingrs == 0)  # positions of the eos token itself
            eos_head = ((target_ingrs != self.pad_value) & (target_ingrs != 0))  # positions before eos

            # select transformer steps to pool from
            mask_perminv = mask_from_eos(target_ingrs,
                                         eos_value=0,
                                         mult_before=False)
            ingr_probs = ingr_logits * mask_perminv.float().unsqueeze(-1)

            ingr_probs, _ = torch.max(ingr_probs, dim=1)

            # ignore predicted ingredients after eos in ground truth
            ingr_ids[mask_perminv == 0] = self.pad_value

            ingr_loss = self.crit_ingr(ingr_probs, target_one_hot_smooth)
            ingr_loss = torch.mean(ingr_loss, dim=-1)

            losses['ingr_loss'] = ingr_loss

            # cardinality penalty
            losses['card_penalty'] = torch.abs((ingr_probs*target_one_hot).sum(1) - target_one_hot.sum(1)) + \
                                     torch.abs((ingr_probs*(1-target_one_hot)).sum(1))

            eos_loss = self.crit_eos(eos, target_eos.float())

            mult = 1 / 2
            # eos loss is only computed for timesteps <= t_eos and equally penalizes 0s and 1s
            losses['eos_loss'] = mult*(eos_loss * eos_pos.float()).sum(1) / (eos_pos.float().sum(1) + 1e-6) + \
                                 mult*(eos_loss * eos_head.float()).sum(1) / (eos_head.float().sum(1) + 1e-6)
            # iou
            pred_one_hot = label2onehot(ingr_ids, self.pad_value)
            # iou sample during training is computed using the true eos position
            losses['iou'] = softIoU(pred_one_hot, target_one_hot)

        if self.ingrs_only:
            return losses

        # encode ingredients
        target_ingr_feats = self.ingredient_encoder(target_ingrs)
        target_ingr_mask = mask_from_eos(target_ingrs,
                                         eos_value=0,
                                         mult_before=False)

        target_ingr_mask = target_ingr_mask.float().unsqueeze(1)

        outputs, ids = self.recipe_decoder(target_ingr_feats, target_ingr_mask,
                                           captions, img_features)

        outputs = outputs[:, :-1, :].contiguous()
        outputs = outputs.view(outputs.size(0) * outputs.size(1), -1)

        loss = self.crit(outputs, targets)

        losses['recipe_loss'] = loss

        return losses
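
The forward pass above leans on three helpers whose bodies are not shown: label2onehot, mask_from_eos, and softIoU. Below is a minimal sketch of semantics consistent with the call sites above; the bodies are illustrative assumptions, not necessarily the repo's exact implementations.

import torch

def label2onehot(labels, pad_value):
    # (batch, seq) integer ids -> (batch, pad_value) multi-hot targets
    one_hot = torch.zeros(labels.size(0), pad_value + 1, device=labels.device)
    one_hot.scatter_(1, labels, 1.0)
    one_hot = one_hot[:, :-1]   # drop the pad position
    one_hot[:, 0] = 0           # index 0 is the eos token, never a real label
    return one_hot

def mask_from_eos(ids, eos_value, mult_before=True):
    # 1s strictly before the first eos (mult_before=False) or through it
    # (mult_before=True); the first position is always kept to avoid all-zero masks
    mask = torch.ones_like(ids, dtype=torch.bool)
    alive = torch.ones(ids.size(0), dtype=torch.bool, device=ids.device)
    for t in range(1, ids.size(1)):
        if mult_before:
            mask[:, t] = alive
            alive = alive & (ids[:, t] != eos_value)
        else:
            alive = alive & (ids[:, t] != eos_value)
            mask[:, t] = alive
    return mask

def softIoU(pred, target, e=1e-6):
    # soft intersection-over-union per sample on (batch, vocab) scores
    num = (pred * target).sum(1)
    den = (pred + target - pred * target).sum(1) + e
    return num / den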
Example #2
def main(args):

    where_to_save = os.path.join(args.save_dir, args.project_name,
                                 args.model_name)
    checkpoints_dir = os.path.join(where_to_save, 'checkpoints')
    logs_dir = os.path.join(where_to_save, 'logs')

    if not args.log_term:
        print("Eval logs will be saved to:",
              os.path.join(logs_dir, 'eval.log'))
        sys.stdout = open(os.path.join(logs_dir, 'eval.log'), 'w')
        sys.stderr = open(os.path.join(logs_dir, 'eval.err'), 'w')

    vars_to_replace = [
        'greedy', 'recipe_only', 'ingrs_only', 'temperature', 'batch_size',
        'maxseqlen', 'get_perplexity', 'use_true_ingrs', 'eval_split',
        'save_dir', 'aux_data_dir', 'recipe1m_dir', 'project_name', 'use_lmdb',
        'beam'
    ]
    store_dict = {}
    for var in vars_to_replace:
        store_dict[var] = getattr(args, var)
    # reload the args saved at training time, then restore the eval-time overrides
    args = pickle.load(open(os.path.join(checkpoints_dir, 'args.pkl'), 'rb'))
    for var in vars_to_replace:
        setattr(args, var, store_dict[var])
    print(args)

    # Image preprocessing
    transforms_list = [
        transforms.Resize(args.crop_size),
        transforms.CenterCrop(args.crop_size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    transform = transforms.Compose(transforms_list)

    # data loader
    data_dir = args.recipe1m_dir
    data_loader, dataset = get_loader(data_dir,
                                      args.aux_data_dir,
                                      args.eval_split,
                                      args.maxseqlen,
                                      args.maxnuminstrs,
                                      args.maxnumlabels,
                                      args.maxnumims,
                                      transform,
                                      args.batch_size,
                                      shuffle=False,
                                      num_workers=args.num_workers,
                                      drop_last=False,
                                      max_num_samples=-1,
                                      use_lmdb=args.use_lmdb,
                                      suff=args.suff)

    ingr_vocab_size = dataset.get_ingrs_vocab_size()
    instrs_vocab_size = dataset.get_instrs_vocab_size()

    args.numgens = 1

    # Build the model
    model = get_model(args, ingr_vocab_size, instrs_vocab_size)
    model_path = os.path.join(args.save_dir, args.project_name,
                              args.model_name, 'checkpoints', 'modelbest.ckpt')

    # overwrite flags for inference
    model.recipe_only = args.recipe_only
    model.ingrs_only = args.ingrs_only

    # Load the trained model parameters
    model.load_state_dict(torch.load(model_path, map_location=map_loc))

    model.eval()
    model = model.to(device)
    results_dict = {'recipes': {}, 'ingrs': {}, 'ingr_iou': {}}
    iou = []
    error_types = {
        'tp_i': 0,
        'fp_i': 0,
        'fn_i': 0,
        'tn_i': 0,
        'tp_all': 0,
        'fp_all': 0,
        'fn_all': 0
    }
    perplexity_list = []
    n_rep, th = 0, 0.3

    for i, (img_inputs, true_caps_batch, ingr_gt, imgid,
            impath) in tqdm(enumerate(data_loader)):

        ingr_gt = ingr_gt.to(device)
        true_caps_batch = true_caps_batch.to(device)

        true_caps_shift = true_caps_batch.clone()[:, 1:].contiguous()
        img_inputs = img_inputs.to(device)

        true_ingrs = ingr_gt if args.use_true_ingrs else None
        for gens in range(args.numgens):
            with torch.no_grad():

                if args.get_perplexity:

                    losses = model(img_inputs,
                                   true_caps_batch,
                                   ingr_gt,
                                   keep_cnn_gradients=False)
                    recipe_loss = losses['recipe_loss']
                    recipe_loss = recipe_loss.view(true_caps_shift.size())
                    non_pad_mask = true_caps_shift.ne(instrs_vocab_size - 1).float()
                    recipe_loss = torch.sum(recipe_loss * non_pad_mask, dim=-1) \
                                  / torch.sum(non_pad_mask, dim=-1)
                    perplexity = torch.exp(recipe_loss)

                    perplexity = perplexity.detach().cpu().numpy().tolist()
                    perplexity_list.extend(perplexity)

                else:

                    outputs = model.sample(img_inputs, args.greedy,
                                           args.temperature, args.beam,
                                           true_ingrs)

                    if not args.recipe_only:
                        fake_ingrs = outputs['ingr_ids']
                        pred_one_hot = label2onehot(fake_ingrs,
                                                    ingr_vocab_size - 1)
                        target_one_hot = label2onehot(ingr_gt,
                                                      ingr_vocab_size - 1)
                        iou_item = torch.mean(
                            softIoU(pred_one_hot, target_one_hot)).item()
                        iou.append(iou_item)

                        update_error_types(error_types, pred_one_hot,
                                           target_one_hot)

                        fake_ingrs = fake_ingrs.detach().cpu().numpy()

                        for ingr_idx, fake_ingr in enumerate(fake_ingrs):

                            iou_item = softIoU(
                                pred_one_hot[ingr_idx].unsqueeze(0),
                                target_one_hot[ingr_idx].unsqueeze(0)).item()
                            results_dict['ingrs'][imgid[ingr_idx]] = []
                            results_dict['ingrs'][imgid[ingr_idx]].append(
                                fake_ingr)
                            results_dict['ingr_iou'][
                                imgid[ingr_idx]] = iou_item

                    if not args.ingrs_only:
                        sampled_ids_batch = outputs['recipe_ids']
                        sampled_ids_batch = sampled_ids_batch.detach().cpu().numpy()

                        for j, sampled_ids in enumerate(sampled_ids_batch):
                            score = compute_score(sampled_ids)
                            if score < th:
                                n_rep += 1
                            # accumulate every generation produced for this image id
                            if imgid[j] not in results_dict['recipes']:
                                results_dict['recipes'][imgid[j]] = []
                            results_dict['recipes'][imgid[j]].append(sampled_ids)
    if args.get_perplexity:
        print("Perplexity computed over", len(perplexity_list), "samples")
        print("Mean perplexity:", np.mean(perplexity_list))
    else:

        if not args.recipe_only:
            ret_metrics = {
                'accuracy': [],
                'f1': [],
                'jaccard': [],
                'f1_ingredients': []
            }
            compute_metrics(ret_metrics,
                            error_types,
                            ['accuracy', 'f1', 'jaccard', 'f1_ingredients'],
                            eps=1e-10,
                            weights=None)

            for k, v in ret_metrics.items():
                print(k, np.mean(v))

        if args.greedy:
            suff = 'greedy'
        else:
            if args.beam != -1:
                suff = 'beam_' + str(args.beam)
            else:
                suff = 'temp_' + str(args.temperature)

        results_file = os.path.join(
            args.save_dir, args.project_name, args.model_name, 'checkpoints',
            args.eval_split + '_' + suff + '_gencaps.pkl')
        print(results_file)
        pickle.dump(results_dict, open(results_file, 'wb'))

        print("Number of samples with excessive repetitions:", n_rep)
Example #3
def main(args):

    # Create model directory & other aux folders for logging
    where_to_save = os.path.join(args.save_dir, args.project_name,
                                 args.model_name)
    checkpoints_dir = os.path.join(where_to_save, 'checkpoints')
    logs_dir = os.path.join(where_to_save, 'logs')
    tb_logs = os.path.join(args.save_dir, args.project_name, 'tb_logs',
                           args.model_name)
    make_dir(where_to_save)
    make_dir(logs_dir)
    make_dir(checkpoints_dir)
    make_dir(tb_logs)
    if args.tensorboard:
        logger = Visualizer(tb_logs, name='visual_results')

    # check if we want to resume from last checkpoint of current model
    if args.resume:
        args = pickle.load(
            open(os.path.join(checkpoints_dir, 'args.pkl'), 'rb'))
        args.resume = True

    # logs to disk
    if not args.log_term:
        print("Training logs will be saved to:",
              os.path.join(logs_dir, 'train.log'))
        sys.stdout = open(os.path.join(logs_dir, 'train.log'), 'w')
        sys.stderr = open(os.path.join(logs_dir, 'train.err'), 'w')

    print(args)
    pickle.dump(args, open(os.path.join(checkpoints_dir, 'args.pkl'), 'wb'))

    # patience init
    curr_pat = 0

    # Build data loader
    data_loaders = {}
    datasets = {}

    data_dir = args.recipe1m_dir
    for split in ['train', 'val']:

        transforms_list = [transforms.Resize(args.image_size)]

        if split == 'train':
            # Image preprocessing, normalization for the pretrained resnet
            transforms_list.append(transforms.RandomHorizontalFlip())
            transforms_list.append(
                transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)))
            transforms_list.append(transforms.RandomCrop(args.crop_size))

        else:
            transforms_list.append(transforms.CenterCrop(args.crop_size))
        transforms_list.append(transforms.ToTensor())
        transforms_list.append(
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)))

        transform = transforms.Compose(transforms_list)
        max_num_samples = max(args.max_eval,
                              args.batch_size) if split == 'val' else -1
        data_loaders[split], datasets[split] = get_loader(
            data_dir,
            args.aux_data_dir,
            split,
            args.maxseqlen,
            args.maxnuminstrs,
            args.maxnumlabels,
            args.maxnumims,
            transform,
            args.batch_size,
            shuffle=(split == 'train'),
            num_workers=args.num_workers,
            drop_last=True,
            max_num_samples=max_num_samples,
            use_lmdb=args.use_lmdb,
            suff=args.suff)

    ingr_vocab_size = datasets[split].get_ingrs_vocab_size()
    instrs_vocab_size = datasets[split].get_instrs_vocab_size()

    # Build the model
    model = get_model(args, ingr_vocab_size, instrs_vocab_size)
    keep_cnn_gradients = False

    decay_factor = 1.0

    # add model parameters
    if args.ingrs_only:
        params = list(model.ingredient_decoder.parameters()) + list(
            model.ingredient_encoder.parameters())
    elif args.recipe_only:
        params = list(model.recipe_decoder.parameters()) + list(
            model.ingredient_encoder.parameters())
    else:
        params = list(model.recipe_decoder.parameters()) + list(model.ingredient_decoder.parameters()) \
                 + list(model.ingredient_encoder.parameters())

    # only train the linear layer in the encoder if we are not transferring from another model
    if args.transfer_from == '':
        params += list(model.image_encoder.linear.parameters())
    params_cnn = list(model.image_encoder.resnet.parameters())

    print("CNN params:", sum(p.numel() for p in params_cnn if p.requires_grad))
    print("decoder params:", sum(p.numel() for p in params if p.requires_grad))
    # start optimizing cnn from the beginning
    if params_cnn is not None and args.finetune_after == 0:
        optimizer = torch.optim.Adam(
            [{
                'params': params
            }, {
                'params': params_cnn,
                'lr': args.learning_rate * args.scale_learning_rate_cnn
            }],
            lr=args.learning_rate,
            weight_decay=args.weight_decay)
        keep_cnn_gradients = True
        print("Fine tuning resnet")
    else:
        optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    if args.resume:
        model_path = os.path.join(args.save_dir, args.project_name,
                                  args.model_name, 'checkpoints', 'model.ckpt')
        optim_path = os.path.join(args.save_dir, args.project_name,
                                  args.model_name, 'checkpoints', 'optim.ckpt')
        optimizer.load_state_dict(torch.load(optim_path, map_location=map_loc))
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
        model.load_state_dict(torch.load(model_path, map_location=map_loc))

    if args.transfer_from != '':
        # loads CNN encoder from transfer_from model
        model_path = os.path.join(args.save_dir, args.project_name,
                                  args.transfer_from, 'checkpoints',
                                  'modelbest.ckpt')
        pretrained_dict = torch.load(model_path, map_location=map_loc)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if 'encoder' in k
        }
        model.load_state_dict(pretrained_dict, strict=False)
        args, model = merge_models(args, model, ingr_vocab_size,
                                   instrs_vocab_size)

    if device != 'cpu' and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    model = model.to(device)
    cudnn.benchmark = True

    if not hasattr(args, 'current_epoch'):
        args.current_epoch = 0

    es_best = 10000 if args.es_metric == 'loss' else 0
    # Train the model
    start = args.current_epoch
    for epoch in range(start, args.num_epochs):

        if args.tensorboard:
            logger.reset()

        # save current epoch for resuming
        args.current_epoch = epoch

        # decay the learning rate on schedule
        if args.decay_lr:
            frac = epoch // args.lr_decay_every
            decay_factor = args.lr_decay_rate**frac
            new_lr = args.learning_rate * decay_factor
            print('Epoch %d. lr: %.5f' % (epoch, new_lr))
            set_lr(optimizer, decay_factor)

        if args.finetune_after != -1 and args.finetune_after < epoch \
                and not keep_cnn_gradients and params_cnn is not None:

            print("Starting to fine tune CNN")
            # start with learning rates as they were (if decayed during training)
            optimizer = torch.optim.Adam(
                [{'params': params},
                 {'params': params_cnn,
                  'lr': decay_factor * args.learning_rate * args.scale_learning_rate_cnn}],
                lr=decay_factor * args.learning_rate)
            keep_cnn_gradients = True

        for split in ['train', 'val']:

            if split == 'train':
                model.train()
            else:
                model.eval()
            total_step = len(data_loaders[split])
            loader = iter(data_loaders[split])

            total_loss_dict = {
                'recipe_loss': [],
                'ingr_loss': [],
                'eos_loss': [],
                'loss': [],
                'iou': [],
                'perplexity': [],
                'iou_sample': [],
                'f1': [],
                'card_penalty': []
            }

            error_types = {
                'tp_i': 0,
                'fp_i': 0,
                'fn_i': 0,
                'tn_i': 0,
                'tp_all': 0,
                'fp_all': 0,
                'fn_all': 0
            }

            torch.cuda.synchronize()
            start = time.time()

            for i in range(total_step):

                img_inputs, captions, ingr_gt, img_ids, paths = next(loader)

                ingr_gt = ingr_gt.to(device)
                img_inputs = img_inputs.to(device)
                captions = captions.to(device)
                true_caps_batch = captions.clone()[:, 1:].contiguous()
                loss_dict = {}

                if split == 'val':
                    with torch.no_grad():
                        losses = model(img_inputs, captions, ingr_gt)

                        if not args.recipe_only:
                            outputs = model(img_inputs,
                                            captions,
                                            ingr_gt,
                                            sample=True)

                            ingr_ids_greedy = outputs['ingr_ids']

                            mask = mask_from_eos(ingr_ids_greedy,
                                                 eos_value=0,
                                                 mult_before=False)
                            ingr_ids_greedy[mask == 0] = ingr_vocab_size - 1
                            pred_one_hot = label2onehot(
                                ingr_ids_greedy, ingr_vocab_size - 1)
                            target_one_hot = label2onehot(
                                ingr_gt, ingr_vocab_size - 1)
                            iou_sample = softIoU(pred_one_hot, target_one_hot)
                            iou_sample = iou_sample.sum() / (
                                torch.nonzero(iou_sample.data).size(0) + 1e-6)
                            loss_dict['iou_sample'] = iou_sample.item()

                            update_error_types(error_types, pred_one_hot,
                                               target_one_hot)

                            del outputs, pred_one_hot, target_one_hot, iou_sample

                else:
                    losses = model(img_inputs,
                                   captions,
                                   ingr_gt,
                                   keep_cnn_gradients=keep_cnn_gradients)

                if not args.ingrs_only:
                    recipe_loss = losses['recipe_loss']

                    recipe_loss = recipe_loss.view(true_caps_batch.size())
                    non_pad_mask = true_caps_batch.ne(instrs_vocab_size -
                                                      1).float()

                    recipe_loss = torch.sum(recipe_loss * non_pad_mask,
                                            dim=-1) / torch.sum(non_pad_mask,
                                                                dim=-1)
                    perplexity = torch.exp(recipe_loss)

                    recipe_loss = recipe_loss.mean()
                    perplexity = perplexity.mean()

                    loss_dict['recipe_loss'] = recipe_loss.item()
                    loss_dict['perplexity'] = perplexity.item()
                else:
                    recipe_loss = 0

                if not args.recipe_only:

                    ingr_loss = losses['ingr_loss']
                    ingr_loss = ingr_loss.mean()
                    loss_dict['ingr_loss'] = ingr_loss.item()

                    eos_loss = losses['eos_loss']
                    eos_loss = eos_loss.mean()
                    loss_dict['eos_loss'] = eos_loss.item()

                    iou_seq = losses['iou']
                    iou_seq = iou_seq.mean()
                    loss_dict['iou'] = iou_seq.item()

                    card_penalty = losses['card_penalty'].mean()
                    loss_dict['card_penalty'] = card_penalty.item()
                else:
                    ingr_loss, eos_loss, card_penalty = 0, 0, 0

                loss = args.loss_weight[0] * recipe_loss + args.loss_weight[1] * ingr_loss \
                       + args.loss_weight[2]*eos_loss + args.loss_weight[3]*card_penalty

                loss_dict['loss'] = loss.item()

                for key in loss_dict.keys():
                    total_loss_dict[key].append(loss_dict[key])

                if split == 'train':
                    model.zero_grad()
                    loss.backward()
                    optimizer.step()

                # Print log info
                if args.log_step != -1 and i % args.log_step == 0:
                    elapsed_time = time.time() - start
                    lossesstr = ""
                    for k in total_loss_dict.keys():
                        if len(total_loss_dict[k]) == 0:
                            continue
                        this_one = "%s: %.4f" % (
                            k, np.mean(total_loss_dict[k][-args.log_step:]))
                        lossesstr += this_one + ', '
                    # print running means of all tracked losses; per-iteration values also go to the tensorboard logs
                    strtoprint = 'Split: %s, Epoch [%d/%d], Step [%d/%d], Losses: %sTime: %.4f' % (
                        split, epoch, args.num_epochs, i, total_step,
                        lossesstr, elapsed_time)
                    print(strtoprint)

                    if args.tensorboard:
                        # logger.histo_summary(model=model, step=total_step * epoch + i)
                        logger.scalar_summary(
                            mode=split + '_iter',
                            epoch=total_step * epoch + i,
                            **{
                                k: np.mean(v[-args.log_step:])
                                for k, v in total_loss_dict.items() if v
                            })

                    torch.cuda.synchronize()
                    start = time.time()
                del loss, losses, captions, img_inputs

            if split == 'val' and not args.recipe_only:
                ret_metrics = {
                    'accuracy': [],
                    'f1': [],
                    'jaccard': [],
                    'f1_ingredients': [],
                    'dice': []
                }
                compute_metrics(
                    ret_metrics,
                    error_types,
                    ['accuracy', 'f1', 'jaccard', 'f1_ingredients', 'dice'],
                    eps=1e-10,
                    weights=None)

                total_loss_dict['f1'] = ret_metrics['f1']
            if args.tensorboard:
                # 1. Log scalar values (scalar summary)
                logger.scalar_summary(
                    mode=split,
                    epoch=epoch,
                    **{k: np.mean(v)
                       for k, v in total_loss_dict.items() if v})

        # Save the model's best checkpoint if performance was improved
        es_value = np.mean(total_loss_dict[args.es_metric])

        # save current model as well
        save_model(model, optimizer, checkpoints_dir, suff='')
        if (args.es_metric == 'loss'
                and es_value < es_best) or (args.es_metric == 'iou_sample'
                                            and es_value > es_best):
            es_best = es_value
            save_model(model, optimizer, checkpoints_dir, suff='best')
            pickle.dump(args,
                        open(os.path.join(checkpoints_dir, 'args.pkl'), 'wb'))
            curr_pat = 0
            print('Saved checkpoint.')
        else:
            curr_pat += 1

        if curr_pat > args.patience:
            break

    if args.tensorboard:
        logger.close()
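
One helper in the loop is worth spelling out. set_lr is not shown; since decay_factor is recomputed from the base rate each epoch (args.lr_decay_rate ** (epoch // args.lr_decay_every)), a plausible version rescales each param group from its initial value rather than compounding. This is an assumption; the actual helper may differ:

def set_lr(optimizer, decay_factor):
    # hypothetical sketch: rescale each group from the lr it was created with,
    # so repeated calls with decay_factor = rate ** (epoch // every) do not compound
    for group in optimizer.param_groups:
        group.setdefault('initial_lr', group['lr'])
        group['lr'] = group['initial_lr'] * decay_factor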
Example #4
    def forward(self, img_inputs, target_ingrs, target_action,
                sample=False, keep_cnn_gradients=False):
        if sample:
            return self.sample(img_inputs, greedy=True)
        img_features = self.image_encoder(img_inputs, keep_cnn_gradients)

        losses = {}

        #######################################################################
        # build one-hot targets for the ingredients
        target_one_hot_ingrs = label2onehot(target_ingrs, self.pad_value_ingrs)
        target_one_hot_smooth_ingrs = label2onehot(target_ingrs, self.pad_value_ingrs)
        # label smoothing (note: the smoothed targets are computed but never used below;
        # self.crit is called with the unsmoothed target_one_hot_ingrs)
        target_one_hot_smooth_ingrs[target_one_hot_smooth_ingrs == 1] = 1 - self.label_smoothing
        target_one_hot_smooth_ingrs[target_one_hot_smooth_ingrs == 0] = \
            self.label_smoothing / target_one_hot_smooth_ingrs.size(-1)
        ingr_ids, ingr_logits = self.ingredient_decoder.sample(None, None, greedy=True,
                                                               temperature=1.0, img_features=img_features,
                                                               first_token_value=0, replacement=False)
        ingr_logits = torch.nn.functional.softmax(ingr_logits, dim=-1)
        ############################
        # ingredient eos loss
        ingr_eos = ingr_logits[:, :, 0]
        target_ingr_eos = ((target_ingrs == 0) ^ (target_ingrs == self.pad_value_ingrs))
        target_ingr_eos = target_ingr_eos.float()
        ingr_eos_loss = self.crit(ingr_eos, target_ingr_eos)
        ingr_eos_loss = torch.mean(ingr_eos_loss, dim=-1)
        losses['ingr_eos_loss'] = ingr_eos_loss
        #########################################
        # ingredient prediction loss
        mask_perminv_ingrs = mask_from_eos(target_ingrs, eos_value=0, mult_before=False)
        ingr_probs = ingr_logits * mask_perminv_ingrs.float().unsqueeze(-1)
        ingr_probs, _ = torch.max(ingr_probs, dim=1)
        ingr_ids[mask_perminv_ingrs == 0] = self.pad_value_ingrs
        ingr_loss = self.crit(ingr_probs, target_one_hot_ingrs)
        ingr_loss = torch.mean(ingr_loss, dim=-1)
        losses['ingr_loss'] = ingr_loss

        # iou
        pred_one_hot_ingrs = label2onehot(ingr_ids, self.pad_value_ingrs)
        losses['ingr_iou'] = softIoU(pred_one_hot_ingrs, target_one_hot_ingrs)

        ################################################################################
        # build one-hot targets for the actions
        target_one_hot_action = label2onehot(target_action, self.pad_value_action)
        target_one_hot_smooth_action = label2onehot(target_action, self.pad_value_action)
        # label smoothing (again computed but not used by self.crit below)
        target_one_hot_smooth_action[target_one_hot_smooth_action == 1] = 1 - self.label_smoothing
        target_one_hot_smooth_action[target_one_hot_smooth_action == 0] = \
            self.label_smoothing / target_one_hot_smooth_action.size(-1)
        action_ids, action_logits = self.action_decoder.sample(None, None, greedy=True,
                                                               temperature=1.0, img_features=img_features,
                                                               first_token_value=0, replacement=False)
        action_logits = torch.nn.functional.softmax(action_logits, dim=-1)

        ############################
        # action eos loss
        action_eos = action_logits[:, :, 0]
        target_action_eos = ((target_action == 0) ^ (target_action == self.pad_value_action))
        target_action_eos = target_action_eos.float()
        action_eos_loss = self.crit(action_eos, target_action_eos)
        action_eos_loss = torch.mean(action_eos_loss, dim=-1)
        losses['action_eos_loss'] = action_eos_loss
        ########################################
        # action prediction loss
        mask_perminv_action = mask_from_eos(target_action, eos_value=0, mult_before=False)
        action_probs = action_logits * mask_perminv_action.float().unsqueeze(-1)
        action_probs, _ = torch.max(action_probs, dim=1)
        action_ids[mask_perminv_action == 0] = self.pad_value_action
        action_loss = self.crit(action_probs, target_one_hot_action)
        action_loss = torch.mean(action_loss, dim=-1)
        losses['action_loss'] = action_loss
        # iou
        pred_one_hot_action = label2onehot(action_ids, self.pad_value_action)
        losses['action_iou'] = softIoU(pred_one_hot_action, target_one_hot_action)

        return losses
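
This two-head variant returns separate ingredient and action losses rather than a combined one, so a caller must weight and reduce them itself. A minimal training step built on the forward above; the equal weighting is a placeholder, not taken from the source:

losses = model(img_inputs, target_ingrs, target_action,
               keep_cnn_gradients=keep_cnn_gradients)
# placeholder weights: tune via validation
loss = (losses['ingr_loss'].mean() + losses['ingr_eos_loss'].mean()
        + losses['action_loss'].mean() + losses['action_eos_loss'].mean())
optimizer.zero_grad()
loss.backward()
optimizer.step()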