def validate(val_loader, decoder, criterion_ce, i2w, device, print_freq, word_map, current_epoch, break_flag, top_x, smoothing_method, print_flag):
    """
    Performs one epoch's validation and computes a corpus-level BLEU-4 score.

    :param val_loader: DataLoader for validation data; each batch unpacks to
        (imgs, caps, caplens, allcaps).
    :param decoder: decoder model; called as decoder(imgs, caps, caplens) and
        expected to return (scores, caps_sorted, decode_lengths, sort_ind).
    :param criterion_ce: cross entropy loss layer
    :param i2w: index-to-word mapping, only forwarded to print_predictions
    :param device: device that imgs, caps and caplens are moved to
    :param print_freq: print running metrics every `print_freq` batches
    :param word_map: word-to-index mapping; the '<start>' and '<pad>' ids are
        stripped from reference captions
    :param current_epoch: currently unused; kept for interface compatibility
    :param break_flag: if truthy, stop after the first 5 batches
    :param top_x: k for the top-k accuracy metric
    :param smoothing_method: forwarded as `smoothing_function` to corpus_bleu
    :param print_flag: if truthy, print the predictions of every batch
    :return: tuple (bleu4, losses) -- BLEU-4 rounded to 4 decimals, and the
        AverageMeter accumulating the per-word validation loss
    """
    decoder.eval()  # eval mode (no dropout or batchnorm)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top5accs = AverageMeter()

    # Token ids removed from reference captions; built once instead of per caption.
    ignored_ids = {word_map['<start>'], word_map['<pad>']}

    # If for n images we have n hypotheses and references a, b, c... for each image:
    # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
    references = []  # references (true captions) for calculating BLEU-4 score
    hypotheses = []  # hypotheses (predictions)

    start = time.time()

    with torch.no_grad():
        for i, data in enumerate(val_loader):

            if break_flag and i == 5:
                break  # only 5 batches

            print('val i', i)
            imgs, caps, caplens, allcaps = data

            # Move to device, if available
            imgs = imgs.to(device)
            caps = caps.to(device)
            caplens = caplens.to(device)

            scores, caps_sorted, decode_lengths, sort_ind = decoder(imgs, caps, caplens)

            # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
            targets = caps_sorted[:, 1:]

            if print_flag:
                print_predictions(scores, targets, i2w)

            # Greedy predictions from the still-padded scores. No clone is needed:
            # packing below builds new tensors and never mutates `scores`.
            # assumes scores is (batch, max_decode_len, vocab) -- TODO confirm
            _, preds = torch.max(scores, dim=2)

            # Remove timesteps that we didn't decode at, or are pads.
            # PackedSequence has four fields in current PyTorch, so take `.data`
            # instead of unpacking it into two names.
            packed_scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
            packed_targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

            # Calculate loss
            loss = criterion_ce(packed_scores, packed_targets)

            # Keep track of metrics
            n_words = sum(decode_lengths)
            losses.update(loss.item(), n_words)
            top5 = accuracy(packed_scores, packed_targets, top_x)
            top5accs.update(top5, n_words)
            batch_time.update(time.time() - start)

            start = time.time()

            if i % print_freq == 0:
                print('Validation: [{0}/{1}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-{topx} Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
                                                                                loss=losses, topx=top_x, top5=top5accs))

            # References: reorder to match the order the decoder sorted the images in.
            # allcaps is never moved to `device`, while sort_ind comes out of the
            # decoder; indexing a CPU tensor with CUDA indices raises, so align devices.
            if torch.is_tensor(sort_ind) and torch.is_tensor(allcaps):
                sort_ind = sort_ind.to(allcaps.device)
            allcaps = allcaps[sort_ind]
            # DIDEC caps of other participants come here

            for img_caps in allcaps.tolist():
                refs_per_img = []
                for cap in img_caps:
                    kept = [w for w in cap if w not in ignored_ids]  # remove <start> and pads
                    if kept:  # drop captions that were only <start>/<pad>
                        refs_per_img.append(kept)
                references.append(refs_per_img)

            # Hypotheses: cut each prediction to its decode length (removes pads).
            # Superfluous <end> tokens are kept, as with teacher forcing.
            hypotheses.extend(p[:length] for p, length in zip(preds.tolist(), decode_lengths))

            assert len(references) == len(hypotheses)

    # Calculate BLEU-4 scores
    bleu4 = corpus_bleu(references, hypotheses, smoothing_function=smoothing_method)
    bleu4 = round(bleu4, 4)

    print('\n * LOSS - {loss.avg:.3f}, TOP-{topx} ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
            loss=losses,
            topx=top_x,
            top5=top5accs,
            bleu=bleu4))

    return bleu4, losses
         print('Epoch', epoch)
         print('Train')

         # Decay learning rate if there is no improvement for 5 consecutive epochs
         # halved
         # Terminate training after 20
         if epochs_since_improvement == 20:
             break

         #if epochs_since_improvement > 0 and epochs_since_improvement % 10 == 0:
         #    adjust_learning_rate(decoder_optimizer, 0.5)

         decoder.train()
         torch.enable_grad()

         batch_time = AverageMeter()  # forward prop. + back prop. time
         data_time = AverageMeter()  # data loading time
         losses = AverageMeter()  # loss (per word decoded)
         top5accs = AverageMeter()  # top5 accuracy

         start = time.time()

         count = 0

         for i, data in enumerate(training_loader):

            if break_flag and count == 1:
                break

            count += 1