def GetResnet101Features():

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    data_folder = 'C:/Users/paoca/Documents/UVA PHD/NLP/PROJECT/UnnecesaryDataFolder'  # folder with data files saved by create_input_files.py
    data_name = 'coco_5_cap_per_img_5_min_word_freq'

    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transform=transforms.Compose([normalize])),
                                               batch_size=5,
                                               shuffle=False,
                                               pin_memory=True)

    with torch.no_grad():
        encoder = Encoder()
        encoder.fine_tune(False)

        emb_dim = 512
        decoder_dim = 512
        encoderVae_encoder = EncodeVAE_Encoder(embed_dim=emb_dim,
                                               decoder_dim=decoder_dim,
                                               vocab_size=len(word_map))
        encoderVae_encoder.fine_tune(False)

        encoder.eval()
        encoderVae_encoder.eval()

        encoder = encoder.to(device)
        encoderVae_encoder = encoderVae_encoder.to(device)

        for i, (imgs, caps, caplens) in enumerate(train_loader):
            if i % 100 == 0:
                print(i)

            imgs = imgs.to(device)
            caps = caps.to(device)
            caplens = caplens.to(device)

            res = encoder(imgs)
            h = encoderVae_encoder(imgs, caps, caplens)

            # Use context managers so each pickle file is closed after writing
            with open(
                    "C:/Users/paoca/Documents/UVA PHD/NLP/PROJECT/UnnecesaryDataFolder/TrainResnet101Features/"
                    + str(i) + ".p", "wb") as f:
                pickle.dump(res[0].cpu().numpy(), f)
            with open(
                    "C:/Users/paoca/Documents/UVA PHD/NLP/PROJECT/UnnecesaryDataFolder/TrainResnet101Features/VAE_"
                    + str(i) + ".p", "wb") as f:
                pickle.dump(h[0].cpu().numpy(), f)
def main(imgurl):
    # Load word map (word2ix)
    with open('input_files/WORDMAP.json', 'r') as j:
        word_map = json.load(j)
    rev_word_map = {v: k for k, v in word_map.items()}  # ix2word

    # Load model
    decoder = DecoderWithAttention(attention_dim=attention_dim,
                                   embed_dim=emb_dim,
                                   decoder_dim=decoder_dim,
                                   vocab_size=len(word_map),
                                   dropout=dropout)
    decoder_optimizer = torch.optim.Adam(params=filter(
        lambda p: p.requires_grad, decoder.parameters()),
                                         lr=decoder_lr)
    encoder = Encoder()
    encoder.fine_tune(fine_tune_encoder)
    encoder_optimizer = torch.optim.Adam(
        params=filter(lambda p: p.requires_grad, encoder.parameters()),
        lr=encoder_lr) if fine_tune_encoder else None

    decoder.load_state_dict(
        torch.load('output_files/BEST_checkpoint_decoder.pth.tar'))
    encoder.load_state_dict(
        torch.load('output_files/BEST_checkpoint_encoder.pth.tar'))

    decoder = decoder.to(device)
    decoder.eval()
    encoder = encoder.to(device)
    encoder.eval()

    # Encode, decode with attention and beam search
    seq, alphas = caption_image_beam_search(encoder,
                                            decoder,
                                            imgurl,
                                            word_map,
                                            beam_size=5)
    alphas = torch.FloatTensor(alphas)

    # Visualize caption and attention of best sequence
    # visualize_att(img, seq, alphas, rev_word_map, args.smooth)

    words = [rev_word_map[ind] for ind in seq]
    caption = ' '.join(words[1:-1])
    visualize_att(imgurl, seq, alphas, rev_word_map)
def main():
    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')  # load the word map file for data_name (e.g. flickr8k)
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    if checkpoint is None: # if there is no checkpoint
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=embedded_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout) # using the archi from models file
        decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr) # Adam optimizer
        encoder = Encoder() # using the archi from models file
        encoder.fine_tune(fine_tune_encoder) # finetune the encoder
        encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                             lr=encoder_lr) if fine_tune_encoder else None # Adam optimizer

    else: # load the checkpoint file to continue training 
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)
    decoder = decoder.to(device) # converts tensors to CUDA variables if gpu is available  
    encoder = encoder.to(device) # converts tensors to CUDA variables if gpu is available  
    criterion = nn.CrossEntropyLoss().to(device) # Loss function
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], # normalizing the data
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True) # Data loader for train set
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True) # Data loader for val set

    # Training Starts !!!!!!!
    for epoch in range(start_epoch, epochs):

        if epochs_since_improvement == 20: # Early stopping if the BLEU score has not improved for 20 epochs
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8) # learning rate decay to help the training process
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8) # learning rate decay to help the training process
        train(train_loader=train_loader, # Training using the encoder, decoder archi, input images, loss function and optimizers
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)
        recent_bleu4 = validate(val_loader=val_loader, # Validation after every epoch
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1 # If no improvement
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer, # Save checkpoint after every epoch
                        decoder_optimizer, recent_bleu4, is_best)
def main():
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read word map (w2i)
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize / load checkpoint
    if checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders (the preprocessing required here: pixel values must be in the
    # range [0,1], and the image must then be normalized by the mean and standard
    # deviation of the ImageNet images' RGB channels.)
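    # Note: CaptionDataset presumably already yields image tensors scaled to [0, 1],
    # which is why only Normalize appears in the Compose below. For raw PIL images the
    # usual pipeline would be, e.g.:
    #     transforms.Compose([transforms.ToTensor(),  # scales pixel values to [0, 1]
    #                         transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                                              std=[0.229, 0.224, 0.225])])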
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transform=transforms.Compose([normalize])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'VAL',
        transform=transforms.Compose([normalize])),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)

    initial_time = time.time()
    print("Initial time", initial_time)

    # Epochs
    for epoch in range(start_epoch, epochs):
        print("Starting epoch ", epoch)

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch,
              initial_time=initial_time)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)
def main():
    """
    Training and validation.
    """

    # In Python, the global keyword lets you modify a variable defined outside the current scope.
    # It is used to make changes to a module-level variable from within a local (function) context.
    '''
    The basic rules for the global keyword in Python are:

    A variable created inside a function is local by default.
    A variable defined outside of a function is global by default; you don't need the global keyword.
    Use the global keyword to read and write a global variable inside a function.
    Using the global keyword outside a function has no effect.

    '''
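    # A minimal standalone sketch of these rules (illustration only, not part of this script):
    #
    #     counter = 0            # module-level (global) variable
    #
    #     def bump():
    #         global counter     # without this line, `counter += 1` raises UnboundLocalError
    #         counter += 1
    #
    #     bump()
    #     print(counter)         # -> 1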

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize / load checkpoint
    if checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        '''
        The filter() built-in constructs an iterator from the elements of an iterable for which a function returns true.

        filter() takes two parameters:

        function - a function that tests whether each element of the iterable is true or false.
                   If None, the identity function is used, so falsy elements are dropped.

        iterable - the iterable to be filtered; it can be a set, list, tuple, or any other iterable.

        filter() returns an iterator over the elements that pass the function check.

        '''
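        # A minimal illustration of filter() (hypothetical values, not part of this script):
        #
        #     list(filter(lambda x: x % 2 == 1, [0, 1, 2, 3, 4]))  # -> [1, 3]
        #     list(filter(None, [0, '', 'a', 5]))                  # -> ['a', 5]
        #
        # Below, it keeps only the decoder parameters with requires_grad=True.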

        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transform=transforms.Compose([normalize])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'VAL',
        transform=transforms.Compose([normalize])),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # If there's no improvement in Bleu score for 20 epochs then stop training
        if epochs_since_improvement == 20:
            break

        # If there's no improvement in Bleu score for 8 epochs lower the lr
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)
def main():
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map, rev_word_map

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    rev_word_map = {v: k for k, v in word_map.items()}

    # Initialize / load checkpoint
    if checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        pretrained_embs, pretrained_embs_dim = load_embeddings(
            '/home/Iwamura/datasets/datasets/GloVe/glove.6B.300d.txt',
            word_map)
        assert pretrained_embs_dim == decoder.embed_dim
        decoder.load_pretrained_embeddings(pretrained_embs)
        decoder.fine_tune_embeddings(True)

        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder_opt = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_opt.fine_tune(fine_tune_encoder_opt)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None
        encoder_optimizer_opt = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder_opt.parameters()),
            lr=encoder_opt_lr) if fine_tune_encoder_opt else None

    else:

        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder_opt = checkpoint['encoder_opt']
        encoder_optimizer_opt = checkpoint['encoder_optimizer_opt']

        # if fine_tune_encoder is True and encoder_optimizer is None and encoder_optimizer_opt is None
        if fine_tune_encoder_opt is True and encoder_optimizer_opt is None:
            encoder_opt.fine_tune(fine_tune_encoder_opt)

            encoder_optimizer_opt = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder_opt.parameters()),
                                                     lr=encoder_opt_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)

    encoder_opt = encoder_opt.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders

    normalize_opt = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transform=transforms.Compose([normalize_opt])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'VAL',
        transform=transforms.Compose([normalize_opt])),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay the learning rate every 4 epochs, and terminate training if there is no improvement for 10 consecutive epochs
        if epochs_since_improvement == 10:
            break
        if epoch > 0 and epoch % 4 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)

            if fine_tune_encoder_opt:
                adjust_learning_rate(encoder_optimizer_opt, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder_opt=encoder_opt,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer_opt=encoder_optimizer_opt,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder_opt=encoder_opt,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement,
                        encoder_opt, decoder, encoder_optimizer_opt,
                        decoder_optimizer, recent_bleu4, is_best)
def main():
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read word map
    #word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    #with open(word_map_file, 'r') as j:
    #    word_map = json.load(j)

    with open("/content/image_captioning/Image-Captioning-Codebase/vocab.pkl",
              "rb") as f:
        vocab = pickle.load(f)

    word_map = vocab.word2idx

    # Initialize / load checkpoint
    if checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    transform_train = transforms.Compose([
        transforms.Resize((224, 224)),      # resize the image to 224x224
        transforms.RandomHorizontalFlip(),  # horizontally flip image with probability 0.5
        transforms.ToTensor(),              # convert the PIL Image to a tensor in [0, 1]
        transforms.Normalize(
            (0.485, 0.456, 0.406),          # normalize image for pre-trained model
            (0.229, 0.224, 0.225))
    ])
    """
    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    """
    train_loader = torch.utils.data.DataLoader(
        Flickr8kDataset(annot_path="/content/", img_path="/content/Flicker8k_Dataset/",
                        split="train", transform=transform_train),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        Flickr8kDataset(annot_path="/content/", img_path="/content/Flicker8k_Dataset/",
                        split="dev", transform=transform_train),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)
def main():
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map, glove_path, emb_dim, rev_word_map

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    rev_word_map = {v: k for k, v in word_map.items()}
    #get glove
    vectors = bcolz.open(f'{glove_path}/6B.300.dat')[:]
    words = pickle.load(open(f'{glove_path}/6B.300_words.pkl', 'rb'))
    word2idx = pickle.load(open(f'{glove_path}/6B.300_idx.pkl', 'rb'))

    glove = {w: vectors[word2idx[w]] for w in words}
    matrix_len = len(word_map)
    weights_matrix = np.zeros((matrix_len, emb_dim))
    words_found = 0

    for i, word in enumerate(word_map.keys()):
        try:
            weights_matrix[i] = glove[word]
            words_found += 1
        except KeyError:
            weights_matrix[i] = np.random.normal(scale=0.6, size=(emb_dim, ))
    # Earlier dtype/device conversion attempts kept for reference:
    #     weights_matrix = np.float64(weights_matrix)
    #     weights_matrix = torch.from_numpy(weights_matrix)
    #     pretrained_embedding = weights_matrix.to(dtype=torch.float)
    #     print(pretrained_embedding.dtype)
    #     if device.type == 'cpu':
    #         pretrained_embedding = torch.FloatTensor(weights_matrix)
    #     else:
    #         pretrained_embedding = torch.cuda.FloatTensor(weights_matrix)

    # Build the embedding tensor once, after the vocabulary loop
    pretrained_embedding = torch.FloatTensor(weights_matrix)

    # Initialize / load checkpoint
    if checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        decoder.load_pretrained_embeddings(
            pretrained_embedding
        )  # pretrained_embeddings should be of dimensions (len(word_map), emb_dim)
        decoder.fine_tune_embeddings(True)  # or False
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transform=transforms.Compose([normalize])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'VAL',
        transform=transforms.Compose([normalize])),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training (commented out here, so only validation runs)
        # train(train_loader=train_loader,
        #       encoder=encoder,
        #       decoder=decoder,
        #       criterion=criterion,
        #       encoder_optimizer=encoder_optimizer,
        #       decoder_optimizer=decoder_optimizer,
        #       epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)
def main():
    """
    Training and validation.
    """

    global best, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder

    if checkpoint is None:
        decoder = DecoderWithAttention(
            attention_dim=attention_dim,
            embed_dim=emb_dim,
            decoder_dim=decoder_dim,
            # vocab_size=len(word_map),
            vocab_size=2,  # X, Y coordinates and use it for regression
            dropout=dropout)

        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)

        decoderMulti = network.__dict__['MultiTask'](output_size)
        decoderMulti_optimizer = torch.optim.Adam(
            decoderMulti.parameters(),
            lr=decoderMulti_lr,
            weight_decay=decoderMulti_lr_weight_decay)

        encoder = Encoder()
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best = checkpoint['b4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']

        decoderMulti = checkpoint['decoderMulti']
        decoderMulti_optimizer = checkpoint['decoderMulti_optimizer']

        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available

    if multiGpu:

        decoder = torch.nn.DataParallel(decoder).to(device)
        encoder = torch.nn.DataParallel(encoder).to(device)
        decoderMulti = torch.nn.DataParallel(decoderMulti).to(device)

    else:

        decoder = decoder.to(device)
        encoder = encoder.to(device)
        decoderMulti = decoderMulti.to(device)

    # Loss function
    criterionBinary = nn.BCELoss().to(device)
    criterionMse = nn.MSELoss().to(device)

    criterion = [criterionBinary, criterionMse]

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader_p = torch.utils.data.DataLoader(fiberDataset_COCO(
        data_folder_p, jason_file_p, image_folder, offset_folder,
        transforms.Compose([transforms.ToTensor(), normalize]), True),
                                                 batch_size=batch_size,
                                                 shuffle=True,
                                                 num_workers=workers,
                                                 pin_memory=True,
                                                 drop_last=True)

    val_loader_p = torch.utils.data.DataLoader(fiberDataset_COCO(
        data_folder_val, jason_file_val, image_folder_val, offset_folder_val,
        transforms.Compose([transforms.ToTensor(), normalize]), True),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True,
                                               drop_last=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        train(train_loader=train_loader_p,
              encoder=encoder,
              decoder=decoder,
              decoderMulti=decoderMulti,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              decoderMulti_optimizer=decoderMulti_optimizer,
              epoch=epoch)

        recent = validate(val_loader=val_loader_p,
                          encoder=encoder,
                          decoder=decoder,
                          decoderMulti=decoderMulti,
                          criterion=criterion)
        #
        # # Check if there was an improvement
        is_best = recent < best
        best = min(recent, best)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

            # Save checkpoint
            save_checkpoint(save_weights_name, epoch, epochs_since_improvement,
                            encoder, decoder, encoder_optimizer,
                            decoder_optimizer, decoderMulti,
                            decoderMulti_optimizer, best, is_best)
def main():
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map, lowest_loss_val

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Train N models and save them to each directory
    for n in range(1, args.num_models + 1):
        # Directory where the model will be saved
        model_out = os.path.join(args.model, "model_{}".format(n))
        # Create the output directory if it does not already exist
        os.makedirs(model_out, exist_ok=True)
        # Initialize / load checkpoint
        if checkpoint is None:
            decoder = DecoderWithAttention(attention_dim=attention_dim,
                                           embed_dim=emb_dim,
                                           decoder_dim=decoder_dim,
                                           vocab_size=len(word_map),
                                           dropout=dropout)
            decoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, decoder.parameters()),
                                                 lr=decoder_lr)
            encoder = Encoder()
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(
                params=filter(lambda p: p.requires_grad, encoder.parameters()),
                lr=encoder_lr) if fine_tune_encoder else None
        else:
            checkpoint = torch.load(checkpoint)
            start_epoch = checkpoint['epoch'] + 1
            epochs_since_improvement = checkpoint['epochs_since_improvement']
            best_bleu4 = checkpoint['bleu-4']
            decoder = checkpoint['decoder']
            decoder_optimizer = checkpoint['decoder_optimizer']
            encoder = checkpoint['encoder']
            encoder_optimizer = checkpoint['encoder_optimizer']
            if fine_tune_encoder is True and encoder_optimizer is None:
                encoder.fine_tune(fine_tune_encoder)
                encoder_optimizer = torch.optim.Adam(params=filter(
                    lambda p: p.requires_grad, encoder.parameters()),
                                                     lr=encoder_lr)

        # Move to GPU, if available
        decoder = decoder.to(device)
        encoder = encoder.to(device)

        # Loss function
        criterion = nn.CrossEntropyLoss().to(device)

        # Custom dataloaders
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_loader = torch.utils.data.DataLoader(CaptionDataset(
            data_folder,
            data_name,
            'TRAIN',
            transform=transforms.Compose([normalize])),
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=workers,
                                                   pin_memory=True)
        val_loader = torch.utils.data.DataLoader(CaptionDataset(
            data_folder,
            data_name,
            'VAL',
            transform=transforms.Compose([normalize])),
                                                 batch_size=batch_size,
                                                 shuffle=True,
                                                 num_workers=workers,
                                                 pin_memory=True)

        # Epochs
        for epoch in range(start_epoch, epochs):
            # Decay learning rate if there is no improvement for 8 consecutive epochs,
            # and terminate training after 50 consecutive epochs without improvement
            if epochs_since_improvement == 50:
                break
            if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
                adjust_learning_rate(decoder_optimizer, 0.8)
                if fine_tune_encoder:
                    adjust_learning_rate(encoder_optimizer, 0.8)

            # One epoch's training
            train(train_loader=train_loader,
                  encoder=encoder,
                  decoder=decoder,
                  criterion=criterion,
                  encoder_optimizer=encoder_optimizer,
                  decoder_optimizer=decoder_optimizer,
                  epoch=epoch)

            # One epoch's validation
            recent_loss, recent_bleu4 = validate(val_loader=val_loader,
                                                 encoder=encoder,
                                                 decoder=decoder,
                                                 criterion=criterion)
            # Check if there was an improvement using bleu
            is_best = recent_bleu4 > best_bleu4
            best_bleu4 = max(recent_bleu4, best_bleu4)
            # Check if there was an improvement using loss
            #is_best = recent_loss < lowest_loss_val
            #lowest_loss_val = min(recent_loss, lowest_loss_val)
            if not is_best:
                epochs_since_improvement += 1
                print("\nEpochs since last improvement: %d\n" %
                      (epochs_since_improvement, ))
            else:
                epochs_since_improvement = 0

            #save_checkpoint_with_dir(model_out, data_name, epoch, epochs_since_improvement,
            #                         encoder, decoder, encoder_optimizer,
            #                         decoder_optimizer, lowest_loss_val, is_best)
            save_checkpoint_with_dir(model_out, data_name, epoch,
                                     epochs_since_improvement, encoder,
                                     decoder, encoder_optimizer,
                                     decoder_optimizer, recent_bleu4, is_best)
        # Delete encoder&decoder objects and reset memory
        del decoder
        del encoder
        torch.cuda.empty_cache()
        # Reset epochs since improvement to 0 for a new round of training
        epochs_since_improvement = 0
        best_bleu4, start_epoch = 0, 0
        checkpoint = None
        fine_tune_encoder = False
def main():
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Load Pretrained Embeddings and compare to Wordmap if True, otherwise reload the pickle file
    if reload_pretrained_embed:
        embeddings_index = dict()
        fid = open(pretrained_embeddings_file, encoding="utf8")
        for line in fid:
            values = line.split()
            word = values[0]
            coefs = np.asarray(values[1:], dtype='float32')
            embeddings_index[word] = coefs
        fid.close()

        pretrained_embeddings = torch.zeros((len(word_map) + 1, emb_dim))

        for word, idx in word_map.items():
            embed_vector = embeddings_index.get(word)
            if embed_vector is not None:
                # use the pretrained vector for words found in the embedding index
                pretrained_embeddings[idx] = torch.from_numpy(embed_vector)
            else:
                # words not found in the embedding index get random uniform vectors
                pretrained_embeddings[idx] = torch.from_numpy(
                    np.random.uniform(-1, 1, emb_dim))

    #   print(pretrained_embeddings[0:2, :])

    #   fid = open("embedding_matrix.pkl","wb")
    #   dump(pretrained_embeddings, fid)
    #   fid.close()

    # else:
    #   pretrained_embeddings = open(pretrained_embedding_matrix, "wb")
    #   print('Successfully Loaded Pretrained Embeddings Pickle')

    # Initialize / load checkpoint
    if checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        # decoder.load_pretrained_embeddings(pretrained_embeddings)  # pretrained_embeddings should be of dimensions (len(word_map), emb_dim)
        # decoder.fine_tune_embeddings(True)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=args['decoder_lr'])

        # encoder = vgg_face_dag() #VGG Face
        encoder = Encoder()  #OG Encoder
        # encoder.cuda()
        # print(summary(encoder, (3, 224, 224)))
        # print('ENCODER SUMMARY')
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=args['encoder_lr']) if fine_tune_encoder else None

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=args['encoder_lr'])

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])  #OG figures
    # normalize = transforms.Normalize(mean= [129.186279296875, 104.76238250732422, 93.59396362304688], #VGG Face figures
    #                                  std= [1, 1, 1])

    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transform=transforms.Compose([normalize])),
                                               batch_size=args['batch_size'],
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    # print('validation_loader')
    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'VAL',
        transform=transforms.Compose([normalize])),
                                             batch_size=args['batch_size'],
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, args['epochs']):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 15
        if epochs_since_improvement == 15:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        # print('validation_loader_2')
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion,
                                epoch=epoch)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0
        tensorboard_writer.add_scalar('BLEU-4/epoch', recent_bleu4, epoch)

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)

    PATH = './cifar_net.pth'
    tensorboard_writer.close()

    print('Task ID number is: {}'.format(task.id))
def fit(t_params, checkpoint=None, m_params=None):

    # info
    data_name = t_params['data_name']
    imgs_path = t_params['imgs_path']
    df_path = t_params['df_path']
    vocab = t_params['vocab']

    start_epoch = 0
    epochs_since_improvement = 0
    best_bleu4 = 0
    epochs = t_params['epochs']
    batch_size = t_params['batch_size']
    workers = t_params['workers']
    encoder_lr = t_params['encoder_lr']
    decoder_lr = t_params['decoder_lr']
    fine_tune_encoder = t_params['fine_tune_encoder']

    # init / load checkpoint
    if checkpoint is None:

        # getting hyperparameters
        attention_dim = m_params['attention_dim']
        embed_dim = m_params['embed_dim']
        decoder_dim = m_params['decoder_dim']
        encoder_dim = m_params['encoder_dim']
        dropout = m_params['dropout']

        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=embed_dim,
                                       decoder_dim=decoder_dim,
                                       encoder_dim=encoder_dim,
                                       vocab_size=len(vocab),
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)

        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None
    # load checkpoint
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # move to gpu, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # dataloaders
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    print('Loading Data')
    train_loader, val_loader = get_loaders(batch_size, imgs_path, df_path,
                                           transform, vocab, workers)
    print('_' * 50)

    print('-' * 20, 'Fitting', '-' * 20)
    for epoch in range(start_epoch, epochs):

        # decay lr if there is no improvement for 8 consecutive epochs and terminate after 20
        if epochs_since_improvement == 20:
            print('No improvement for 20 consecutive epochs, terminating...')
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        print('_' * 50)
        print('-' * 20, 'Training', '-' * 20)
        # one epoch of training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # one epoch of validation
        print('-' * 20, 'Validation', '-' * 20)
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # check for improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print(f'\nEpochs since last improvement: {epochs_since_improvement}')
        else:
            # reset
            epochs_since_improvement = 0

        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)
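
save_checkpoint is also assumed rather than shown. A hedged sketch that matches the keys these examples read back when resuming from a checkpoint ('epoch', 'epochs_since_improvement', 'bleu-4', 'encoder', 'decoder', 'encoder_optimizer', 'decoder_optimizer') could be:

import torch

def save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder,
                    encoder_optimizer, decoder_optimizer, bleu4, is_best):
    # Bundle everything needed to resume training into one dict.
    state = {'epoch': epoch,
             'epochs_since_improvement': epochs_since_improvement,
             'bleu-4': bleu4,
             'encoder': encoder,
             'decoder': decoder,
             'encoder_optimizer': encoder_optimizer,
             'decoder_optimizer': decoder_optimizer}
    filename = 'checkpoint_' + data_name + '.pth.tar'
    torch.save(state, filename)
    if is_best:
        # Keep a separate copy of the best-scoring model so it is never overwritten.
        torch.save(state, 'BEST_' + filename)

Some of the later examples pass an extra output path or use different keys, so treat this only as the general shape of the helper.
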
Example #13
def main():
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize / load checkpoint
    if use_sam:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout,
                                       use_glove=use_glove,
                                       word_map=word_map)
        base_optimizer = torch.optim.SGD
        decoder_optimizer = SAM(filter(lambda p: p.requires_grad,
                                       decoder.parameters()),
                                base_optimizer,
                                lr=decoder_lr,
                                momentum=0.9)

        checkpoint = torch.load(checkpoint)
        encoder = checkpoint['encoder']
        encoder_optimizer = None
        print("Loading best encoder but random decoder and using SAM...")

    elif checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout,
                                       use_glove=use_glove,
                                       word_map=word_map)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        print(f"Continuing training from epoch {start_epoch}...")
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        if use_sam:
            lr = checkpoint['decoder_optimizer'].param_groups[0]['lr']
            base_optimizer = torch.optim.SGD
            decoder_optimizer = SAM(filter(lambda p: p.requires_grad,
                                           decoder.parameters()),
                                    base_optimizer,
                                    lr=lr,
                                    momentum=0.9)
        else:
            decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']

        if use_sam and fine_tune_encoder is True:
            lr = checkpoint['encoder_optimizer'].param_groups[0]['lr']
            base_optimizer = torch.optim.SGD
            encoder_optimizer = SAM(filter(lambda p: p.requires_grad,
                                           encoder.parameters()),
                                    base_optimizer,
                                    lr=lr,
                                    momentum=0.9)
        else:
            encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            if use_sam:
                base_optimizer = torch.optim.SGD
                encoder_optimizer = SAM(filter(lambda p: p.requires_grad,
                                               encoder.parameters()),
                                        base_optimizer,
                                        lr=encoder_lr,
                                        momentum=0.9)
            else:
                encoder_optimizer = torch.optim.Adam(params=filter(
                    lambda p: p.requires_grad, encoder.parameters()),
                                                     lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # initialize dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(CocoCaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transforms=transforms.Compose([normalize])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(CocoCaptionDataset(
        data_folder,
        data_name,
        'VAL',
        transforms=transforms.Compose([normalize])),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)

    print(f"Train dataloader len: {len(train_loader)}")
    print(f"Val dataloader len: {len(val_loader)}")

    # set up tensorboard
    train_writer = SummaryWriter(
        os.path.join(log_directory, f"{log_name}/train"))
    val_writer = SummaryWriter(os.path.join(log_directory, f"{log_name}/val"))

    # Epochs
    for epoch in tqdm(range(start_epoch, epochs)):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch,
              train_writer=train_writer)

        # One epoch's validation
        recent_bleu4, val_loss, val_top5_acc = validate(val_loader=val_loader,
                                                        encoder=encoder,
                                                        decoder=decoder,
                                                        criterion=criterion)
        val_writer.add_scalar('Epoch loss', val_loss, epoch + 1)
        val_writer.add_scalar('Epoch top-5 accuracy', val_top5_acc, epoch + 1)
        val_writer.add_scalar('BLEU-4', recent_bleu4, epoch + 1)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        checkpoint_name = data_name
        if use_glove:
            checkpoint_name = f"glove_{checkpoint_name}"
        if use_sam:
            checkpoint_name = f"sam_{checkpoint_name}"
        save_checkpoint(checkpoint_name, epoch, epochs_since_improvement,
                        encoder, decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best, checkpoint_path)
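
This example constructs a SAM-wrapped optimizer but the train() that steps it is not shown. As a hedged sketch, assuming the widely used SAM implementation that exposes first_step()/second_step() and placeholder model/criterion names, one batch takes two forward-backward passes:

def sam_training_step(model, sam_optimizer, criterion, inputs, targets):
    # First pass: compute the loss and climb to the locally worst-case weights.
    criterion(model(inputs), targets).backward()
    sam_optimizer.first_step(zero_grad=True)
    # Second pass: evaluate the loss at the perturbed weights and take the real update.
    criterion(model(inputs), targets).backward()
    sam_optimizer.second_step(zero_grad=True)
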
Example #14
def main():
    """
    Describe main process including train and validation.
    """

    global start_epoch, checkpoint, fine_tune_encoder, best_bleu4, epochs_since_improvement, word_map

    # Read word map
    word_map_path = os.path.join(data_folder,
                                 'WORDMAP_' + dataset_name + ".json")
    with open(word_map_path, 'r') as j:
        word_map = json.load(j)

    # Set checkpoint or read from checkpoint
    if checkpoint is None:  # No pretrained model, set model from beginning
        decoder = Decoder(embed_dim=embed_dim,
                          decoder_dim=decoder_dim,
                          vocab_size=len(word_map),
                          dropout=dropout_rate)
        # Synchronize the freshly initialized parameters across nodes
        # (all-reduce sum, then scale by 1/sqrt(num_nodes)).
        decoder_param = list(
            filter(lambda p: p.requires_grad, decoder.parameters()))
        for param in decoder_param:
            tensor0 = param.data
            dist.all_reduce(tensor0, op=dist.ReduceOp.SUM)
            param.data = tensor0 / np.sqrt(float(num_nodes))
        decoder_optimizer = optim.Adam(params=decoder_param, lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_param = list(
            filter(lambda p: p.requires_grad, encoder.parameters()))
        if fine_tune_encoder:
            for param in encoder_param:
                tensor0 = param.data
                dist.all_reduce(tensor0, op=dist.ReduceOp.SUM)
                param.data = tensor0 / np.sqrt(float(num_nodes))
        encoder_optimizer = optim.Adam(
            params=encoder_param, lr=encoder_lr) if fine_tune_encoder else None
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint["epoch"] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    decoder = decoder.to(device)
    encoder = encoder.to(device)
    criterion = nn.CrossEntropyLoss()

    # Data loader
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_set = CaptionDataset(data_folder=h5data_folder,
                               data_name=dataset_name,
                               split="TRAIN",
                               transform=transforms.Compose([normalize]))
    val_set = CaptionDataset(data_folder=h5data_folder,
                             data_name=dataset_name,
                             split="VAL",
                             transform=transforms.Compose([normalize]))
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=workers,
                              pin_memory=True)
    val_loader = DataLoader(val_set,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=workers,
                            pin_memory=True)

    total_start_time = datetime.datetime.now()
    print("Start the 1st epoch at: ", total_start_time)

    # Epoch
    for epoch in range(start_epoch, num_epochs):
        # Pre-check by epochs_since_improvement
        if epochs_since_improvement == 20:  # If there are 20 epochs that no improvements are achieved
            break
        if epochs_since_improvement % 8 == 0 and epochs_since_improvement > 0:
            adjust_learning_rate(decoder_optimizer)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer)

        # For every batch
        batch_time = AverageMeter()  # forward prop. + back prop. time
        data_time = AverageMeter()  # data loading time
        losses = AverageMeter()  # loss (per word decoded)
        top5accs = AverageMeter()  # top5 accuracy
        decoder.train()
        encoder.train()

        start = time.time()
        start_time = datetime.datetime.now()  # Initialize start time for this epoch

        # TRAIN
        for j, (images, captions, caplens) in enumerate(train_loader):
            if fine_tune_encoder and (epoch - start_epoch > 0 or j > 10):
                for group in encoder_optimizer.param_groups:
                    for p in group['params']:
                        state = encoder_optimizer.state[p]
                        if (state['step'] >= 1024):
                            state['step'] = 1000

            if (epoch - start_epoch > 0 or j > 10):
                for group in decoder_optimizer.param_groups:
                    for p in group['params']:
                        state = decoder_optimizer.state[p]
                        if (state['step'] >= 1024):
                            state['step'] = 1000

            data_time.update(time.time() - start)

            images = images.to(device)
            captions = captions.to(device)
            caplens = caplens.to(device)
            # Forward
            enc_images = encoder(images)
            predictions, enc_captions, dec_lengths, sort_ind = decoder(
                enc_images, captions, caplens)

            # Define target as original captions excluding <start>
            target = enc_captions[:, 1:]  # (batch_size, max_caption_length-1)
            target, _ = pack_padded_sequence(
                target, dec_lengths, batch_first=True
            )  # Delete all paddings and concat all other parts
            predictions, _ = pack_padded_sequence(
                predictions, dec_lengths,
                batch_first=True)  # (sum(dec_lengths), vocab_size)

            loss = criterion(predictions, target)

            # Backward
            decoder_optimizer.zero_grad()
            if encoder_optimizer is not None:
                encoder_optimizer.zero_grad()
            loss.backward()
            ## Clip gradients
            if grad_clip is not None:
                clip_gradient(decoder_optimizer, grad_clip)
                if encoder_optimizer is not None:
                    clip_gradient(encoder_optimizer, grad_clip)
            ## Update
            decoder_optimizer.step()
            if encoder_optimizer is not None:
                encoder_optimizer.step()

            # Update metrics (AverageMeter)
            acc_top5 = compute_accuracy(predictions, target, k=5)
            top5accs.update(acc_top5, sum(dec_lengths))
            losses.update(loss.item(), sum(dec_lengths))
            batch_time.update(time.time() - start)

            # Print current status
            if (j + 1) % print_freq == 0:
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Current Batch Time: {batch_time.val:.3f} (Average: {batch_time.avg:.3f})\t'
                    'Current Data Load Time: {data_time.val:.3f} (Average: {data_time.avg:.3f})\t'
                    'Current Loss: {loss.val:.4f} (Average: {loss.avg:.4f})\t'
                    'Current Top-5 Accuracy: {top5.val:.3f} (Average: {top5.avg:.3f})'
                    .format(epoch + 1,
                            j + 1,
                            len(train_loader),
                            batch_time=batch_time,
                            data_time=data_time,
                            loss=losses,
                            top5=top5accs))
                now_time = datetime.datetime.now()
                print("Epoch Training Time: ", now_time - start_time)
                print("Total Time: ", now_time - total_start_time)

            start = time.time()

        # VALIDATION
        decoder.eval()
        encoder.eval()

        batch_time = AverageMeter()  # forward prop. + back prop. time
        losses = AverageMeter()  # loss (per word decoded)
        top5accs = AverageMeter()  # top5 accuracy
        references = list(
        )  # references (true captions) for calculating BLEU-4 score
        hypotheses = list()  # hypotheses (predictions)

        start_time = datetime.datetime.now()

        for j, (images, captions, caplens, all_caps) in enumerate(val_loader):
            start = time.time()

            images = images.to(device)
            captions = captions.to(device)
            caplens = caplens.to(device)

            # Forward
            enc_images = encoder(images)
            predictions, enc_captions, dec_lengths, sort_ind = decoder(
                enc_images, captions, caplens)

            # Define target as original captions excluding <start>
            predictions_copy = predictions.clone()
            target = enc_captions[:, 1:]  # (batch_size, max_caption_length-1)
            target, _ = pack_padded_sequence(
                target, dec_lengths, batch_first=True
            )  # Delete all paddings and concat all other parts
            predictions, _ = pack_padded_sequence(
                predictions, dec_lengths,
                batch_first=True)  # (sum(dec_lengths), vocab_size)

            loss = criterion(predictions, target)

            # Update metrics (AverageMeter)
            acc_top5 = compute_accuracy(predictions, target, k=5)
            top5accs.update(acc_top5, sum(dec_lengths))
            losses.update(loss.item(), sum(dec_lengths))
            batch_time.update(time.time() - start)

            # Print current status
            if (j + 1) % print_freq == 0:
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(
                        epoch + 1,
                        j,
                        len(val_loader),
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses,
                        top5=top5accs))
                now_time = datetime.datetime.now()
                print("Epoch Validation Time: ", now_time - start_time)
                print("Total Time: ", now_time - total_start_time)

            ## Store references (true captions), and hypothesis (prediction) for each image
            ## If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
            ## references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]

            # references
            all_caps = all_caps[sort_ind]
            for k in range(all_caps.shape[0]):
                img_caps = all_caps[k].tolist()
                img_captions = list(
                    map(
                        lambda c: [
                            w for w in c if w not in
                            {word_map["<start>"], word_map["<pad>"]}
                        ], img_caps))
                references.append(img_captions)

            # hypotheses
            _, preds = torch.max(predictions_copy, dim=2)
            preds = preds.tolist()
            temp_preds = list()
            for i, p in enumerate(preds):
                temp_preds.append(preds[i][:dec_lengths[i]])  # remove pads
            preds = temp_preds
            hypotheses.extend(preds)

            assert len(references) == len(hypotheses)

        ## Compute BLEU-4 Scores
        #recent_bleu4 = corpus_bleu(references, hypotheses, emulate_multibleu=True)
        recent_bleu4 = corpus_bleu(references, hypotheses)

        print(
            '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'
            .format(loss=losses, top5=top5accs, bleu=recent_bleu4))

        # CHECK IMPROVEMENT
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement))
        else:
            epochs_since_improvement = 0

        # SAVE CHECKPOINT
        save_checkpoint(dataset_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)
        print("Epoch {}, cost time: {}\n".format(epoch + 1,
                                                 now_time - total_start_time))
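
AverageMeter, compute_accuracy and clip_gradient are used above but defined elsewhere. Minimal sketches consistent with how they are called (running averages weighted by sum(dec_lengths), top-k accuracy in percent, and per-element gradient clamping) might be:

class AverageMeter(object):
    """Tracks the latest value, running sum, count and average of a metric."""
    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def compute_accuracy(scores, targets, k):
    """Top-k accuracy (%) over packed word predictions of shape (N, vocab_size)."""
    batch_size = targets.size(0)
    _, ind = scores.topk(k, dim=1, largest=True, sorted=True)
    correct = ind.eq(targets.view(-1, 1).expand_as(ind))
    return correct.view(-1).float().sum().item() * (100.0 / batch_size)

def clip_gradient(optimizer, grad_clip):
    """Clamp every gradient element to [-grad_clip, grad_clip]."""
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)
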
def main():

    print('Training parameters Initialized')
    training_parameters = TrainingParameters( start_epoch = 0,
                                            epochs = 120,  # number of epochs to train for
                                            epochs_since_improvement = 0,  # Epochs since improvement in BLEU score
                                            batch_size = 32,
                                            workers = 1,  # for data-loading; right now, only 1 works with h5py
                                            fine_tune_encoder = True,  # fine-tune encoder
                                            encoder_lr = 1e-4,  # learning rate for encoder, if fine-tuning is used
                                            decoder_lr = 4e-4,  # learning rate for decoder
                                            grad_clip = 5.0,  # clip gradients at an absolute value of
                                            alpha_c = 1.0,  # regularization parameter for 'doubly stochastic attention'
                                            best_bleu4 = 0.0,  # BLEU-4 score right now
                                            print_freq = 100,  # print training/validation stats every __ batches
                                            checkpoint =  './Result/BEST_checkpoint_flickr8k_5_captions_per_image_5_minimum_word_frequency.pth.tar' # path to checkpoint, None if none
                                            # checkpoint = None
                                          )

    print('Loading Word-Map')
    word_map_file = os.path.join(data_folder,'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    print('Creating Model')

    if training_parameters.checkpoint is None:
        encoder = Encoder()
        encoder.fine_tune(training_parameters.fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(params=filter(lambda p : p.requires_grad, encoder.parameters()),
                                                lr=training_parameters.encoder_lr) if training_parameters.fine_tune_encoder else None
        
        decoder = Decoder(attention_dimension = attention_dimension,
                            embedding_dimension = embedding_dimension,
                            hidden_dimension = hidden_dimension,
                            vocab_size = len(word_map),
                            device = device,
                            dropout = dropout)                            
        decoder_optimizer = torch.optim.Adam(params=filter(lambda p : p.requires_grad, decoder.parameters()),
                                                lr=training_parameters.decoder_lr)

    else:
        checkpoint = torch.load(training_parameters.checkpoint)
        training_parameters.start_epoch = checkpoint['epoch'] + 1
        training_parameters.epochs_since_improvement = checkpoint['epochs_since_improvement']
        training_parameters.best_bleu4 = checkpoint['bleu4']

        encoder = Encoder()
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        encoder_optimizer = checkpoint['encoder_optimizer']

        decoder = Decoder(attention_dimension = attention_dimension,
                            embedding_dimension = embedding_dimension,
                            hidden_dimension = hidden_dimension,
                            vocab_size = len(word_map),
                            device = device,
                            dropout = dropout)
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        decoder_optimizer = checkpoint['decoder_optimizer']

        if training_parameters.fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(training_parameters.fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(lambda p : p.requires_grad, encoder.parameters()),
                                                 lr=training_parameters.encoder_lr)

    encoder.to(device)
    decoder.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
        
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    print('Creating Data Loaders')
    train_dataloader = torch.utils.data.DataLoader(
                                    CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
                                    batch_size=training_parameters.batch_size, shuffle=True)
    
    validation_dataloader = torch.utils.data.DataLoader(
                                    CaptionDataset(data_folder, data_name, 'VALID', transform=transforms.Compose([normalize])),
                                    batch_size=training_parameters.batch_size, shuffle=True, pin_memory=True)

    for epoch in range(training_parameters.start_epoch, training_parameters.epochs):

        if training_parameters.epochs_since_improvement == 20:
            break
        if training_parameters.epochs_since_improvement > 0  and training_parameters.epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if training_parameters.fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        train(train_loader = train_dataloader,
              encoder = encoder,
              decoder = decoder,
              criterion = criterion,
              encoder_optimizer = encoder_optimizer,
              decoder_optimizer = decoder_optimizer,
              epoch = epoch,
              device = device,
              training_parameters = training_parameters)

        recent_bleu4_score = validate(validation_loader = validation_dataloader,
                                    encoder = encoder,
                                    decoder = decoder,
                                    criterion = criterion,
                                    word_map = word_map,
                                    device = device,
                                    training_parameters = training_parameters)

        is_best_score = recent_bleu4_score > training_parameters.best_bleu4
        training_parameters.best_bleu4 = max(recent_bleu4_score, training_parameters.best_bleu4)
        if not is_best_score:
            training_parameters.epochs_since_improvement += 1
            print('\nEpochs since last improvement : %d\n' % (training_parameters.epochs_since_improvement))
        else:
            training_parameters.epochs_since_improvement = 0
        
        save_checkpoint(data_name, epoch, training_parameters.epochs_since_improvement, encoder, decoder,
                        encoder_optimizer, decoder_optimizer, recent_bleu4_score, is_best_score)
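
TrainingParameters is constructed with keyword arguments and mutated in place, but its definition is not included here. A hypothetical container matching the fields used above (names, types and defaults are assumptions) could be a simple dataclass:

from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainingParameters:
    start_epoch: int = 0
    epochs: int = 120
    epochs_since_improvement: int = 0
    batch_size: int = 32
    workers: int = 1
    fine_tune_encoder: bool = True
    encoder_lr: float = 1e-4
    decoder_lr: float = 4e-4
    grad_clip: float = 5.0
    alpha_c: float = 1.0
    best_bleu4: float = 0.0
    print_freq: int = 100
    checkpoint: Optional[str] = None
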
def main():
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read word map
    word_map_file = os.path.join(data_folder, "WORDMAP_" + data_name + ".json")
    with open(word_map_file, "r") as j:
        word_map = json.load(j)

    # Initialize / load checkpoint
    if checkpoint is None:
        decoder = DecoderWithAttention(
            attention_dim=attention_dim,
            embed_dim=emb_dim,
            decoder_dim=decoder_dim,
            vocab_size=len(word_map),
            dropout=dropout,
        )
        decoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, decoder.parameters()),
            lr=decoder_lr,
        )
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = (torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr,
        ) if fine_tune_encoder else None)

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint["epoch"] + 1
        epochs_since_improvement = checkpoint["epochs_since_improvement"]
        best_bleu4 = checkpoint["bleu-4"]
        decoder = checkpoint["decoder"]
        decoder_optimizer = checkpoint["decoder_optimizer"]
        encoder = checkpoint["encoder"]
        encoder_optimizer = checkpoint["encoder_optimizer"]
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(
                params=filter(lambda p: p.requires_grad, encoder.parameters()),
                lr=encoder_lr,
            )

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder,
                       data_name,
                       "TRAIN",
                       transform=transforms.Compose([normalize])),
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
    )
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder,
                       data_name,
                       "VAL",
                       transform=transforms.Compose([normalize])),
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
    )

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(
            train_loader=train_loader,
            encoder=encoder,
            decoder=decoder,
            criterion=criterion,
            encoder_optimizer=encoder_optimizer,
            decoder_optimizer=decoder_optimizer,
            epoch=epoch,
        )

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(
            data_name,
            epoch,
            epochs_since_improvement,
            encoder,
            decoder,
            encoder_optimizer,
            decoder_optimizer,
            recent_bleu4,
            is_best,
        )
def main():
    """
    Training and validation.
    """

    global epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map, role_map
    #print('reading word map')
    # Read word map
    word_map_file = os.path.join(data_folder, 'token2id' + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)
    #print('reading role map')
    role_map_file = os.path.join(data_folder, 'roles2id' + '.json')
    with open(role_map_file, 'r') as j:
        role_map = json.load(j)
    #print('initializing..')
    # Initialize / load checkpoint
    if checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       role_vocab_size=len(role_map),
                                       role_embed_dim=role_dim,
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        #best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)
    #print('creating encoder/decoder..')
    #encoder = nn.DataParallel(encoder,device_ids=[0,1])
    #decoder = nn.DataParallel(decoder,device_ids=[0,1])
    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    #print('creating dataloader..')
    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(FrameDataset(
        data_folder, 'TRAIN', transform=transforms.Compose([normalize])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    # val_loader = torch.utils.data.DataLoader(
    #     FrameDataset(data_folder, 'VAL', transform=transforms.Compose([normalize])),
    #     batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # NOTE: no learning-rate decay is applied in this example
        # One epoch's training
        #print('start training')
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)
        print('start validation..')
Example #18
def train(args):
    cfg_from_file(args.cfg)
    cfg.WORKERS = args.num_workers
    pprint.pprint(cfg)
    # set the seed manually
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # define outputer
    outputer_train = Outputer(args.output_dir, cfg.IMAGETEXT.PRINT_EVERY,
                              cfg.IMAGETEXT.SAVE_EVERY)
    outputer_val = Outputer(args.output_dir, cfg.IMAGETEXT.PRINT_EVERY,
                            cfg.IMAGETEXT.SAVE_EVERY)
    # define the dataset
    split_dir, bshuffle = 'train', True

    # Get data loader
    imsize = cfg.TREE.BASE_SIZE * (2**(cfg.TREE.BRANCH_NUM - 1))
    train_transform = transforms.Compose([
        transforms.Resize(int(imsize * 76 / 64)),  # transforms.Scale is deprecated
        transforms.RandomCrop(imsize),
    ])
    val_transform = transforms.Compose([
        transforms.Resize(int(imsize * 76 / 64)),
        transforms.CenterCrop(imsize),
    ])
    if args.dataset == 'bird':
        train_dataset = ImageTextDataset(args.data_dir,
                                         split_dir,
                                         transform=train_transform,
                                         sample_type='train')
        val_dataset = ImageTextDataset(args.data_dir,
                                       'val',
                                       transform=val_transform,
                                       sample_type='val')
    elif args.dataset == 'coco':
        train_dataset = CaptionDataset(args.data_dir,
                                       split_dir,
                                       transform=train_transform,
                                       sample_type='train',
                                       coco_data_json=args.coco_data_json)
        val_dataset = CaptionDataset(args.data_dir,
                                     'val',
                                     transform=val_transform,
                                     sample_type='val',
                                     coco_data_json=args.coco_data_json)
    else:
        raise NotImplementedError

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.IMAGETEXT.BATCH_SIZE,
        shuffle=bshuffle,
        num_workers=int(cfg.WORKERS))
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=cfg.IMAGETEXT.BATCH_SIZE,
        shuffle=False,
        num_workers=1)
    # define the model and optimizer
    if args.raw_checkpoint != '':
        encoder, decoder = load_raw_checkpoint(args.raw_checkpoint)
    else:
        encoder = Encoder()
        decoder = DecoderWithAttention(
            attention_dim=cfg.IMAGETEXT.ATTENTION_DIM,
            embed_dim=cfg.IMAGETEXT.EMBED_DIM,
            decoder_dim=cfg.IMAGETEXT.DECODER_DIM,
            vocab_size=train_dataset.n_words)
        # load checkpoint
        if cfg.IMAGETEXT.CHECKPOINT != '':
            outputer_val.log("load model from: {}".format(
                cfg.IMAGETEXT.CHECKPOINT))
            encoder, decoder = load_checkpoint(encoder, decoder,
                                               cfg.IMAGETEXT.CHECKPOINT)

    encoder.fine_tune(False)
    # to cuda
    encoder = encoder.cuda()
    decoder = decoder.cuda()
    loss_func = torch.nn.CrossEntropyLoss()
    if args.eval:  # eval only
        outputer_val.log("only eval the model...")
        assert cfg.IMAGETEXT.CHECKPOINT != ''
        val_rtn_dict, outputer_val = validate_one_epoch(
            0, val_dataloader, encoder, decoder, loss_func, outputer_val)
        outputer_val.log("\n[valid]: {}\n".format(dict2str(val_rtn_dict)))
        return

    # define optimizer
    optimizer_encoder = torch.optim.Adam(encoder.parameters(),
                                         lr=cfg.IMAGETEXT.ENCODER_LR)
    optimizer_decoder = torch.optim.Adam(decoder.parameters(),
                                         lr=cfg.IMAGETEXT.DECODER_LR)
    encoder_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer_encoder, step_size=10, gamma=cfg.IMAGETEXT.LR_GAMMA)
    decoder_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer_decoder, step_size=10, gamma=cfg.IMAGETEXT.LR_GAMMA)
    print("train the model...")
    for epoch_idx in range(cfg.IMAGETEXT.EPOCH):
        # val_rtn_dict, outputer_val = validate_one_epoch(epoch_idx, val_dataloader, encoder,
        #         decoder, loss_func, outputer_val)
        # outputer_val.log("\n[valid] epoch: {}, {}".format(epoch_idx, dict2str(val_rtn_dict)))
        train_rtn_dict, outputer_train = train_one_epoch(
            epoch_idx, train_dataloader, encoder, decoder, optimizer_encoder,
            optimizer_decoder, loss_func, outputer_train)
        # adjust lr scheduler
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()

        outputer_train.log("\n[train] epoch: {}, {}\n".format(
            epoch_idx, dict2str(train_rtn_dict)))
        val_rtn_dict, outputer_val = validate_one_epoch(
            epoch_idx, val_dataloader, encoder, decoder, loss_func,
            outputer_val)
        outputer_val.log("\n[valid] epoch: {}, {}\n".format(
            epoch_idx, dict2str(val_rtn_dict)))

        outputer_val.save_step({
            "encoder": encoder.state_dict(),
            "decoder": decoder.state_dict()
        })
    outputer_val.save({
        "encoder": encoder.state_dict(),
        "decoder": decoder.state_dict()
    })
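
load_checkpoint is assumed above. Since outputer_val.save() writes plain state dicts under the keys "encoder" and "decoder", a plausible counterpart (a sketch, not the original helper) simply loads them back into existing model instances:

import torch

def load_checkpoint(encoder, decoder, checkpoint_path):
    # Mirror of the {"encoder": ..., "decoder": ...} state-dict layout saved above.
    state = torch.load(checkpoint_path, map_location='cpu')
    encoder.load_state_dict(state['encoder'])
    decoder.load_state_dict(state['decoder'])
    return encoder, decoder
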
Example #19
def main():
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read the word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize the model or load it from a checkpoint
    if checkpoint is None:
        decoder = DecoderWithAttention(hidden_size=hidden_size,
                                       vocab_size=len(word_map),
                                       attention_dim=attention_dim,
                                       embed_size=emb_dim,
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(params=decoder.parameters(),
                                             lr=decoder_lr,
                                             betas=(0.8, 0.999))
        encoder = Encoder(hidden_size=hidden_size,
                          embed_size=emb_dim,
                          dropout=dropout)
        # Whether to fine-tune the encoder
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr,
            betas=(0.8, 0.999)) if fine_tune_encoder else None

    else:
        # Load checkpoint
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            # If fine-tuning is being turned on now, an encoder optimizer has to be created
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr,
                                                 betas=(0.8, 0.999))

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])  #ImageNet
    # pin_memory=True keeps batches in page-locked (pinned) memory instead of swapping them out
    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transform=transforms.Compose([normalize])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'VAL',
        transform=transforms.Compose([normalize])),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        if epoch > 15:
            adjust_learning_rate(decoder_optimizer, epoch)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, epoch)

        # Early stopping if the validation score does not improve for 6 consecutive epochs
        if epochs_since_improvement == 6:
            break

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch,
              vocab_size=len(word_map))

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)
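
Note that this example calls adjust_learning_rate(optimizer, epoch) with the epoch index rather than a shrink factor, so its helper presumably implements an epoch-indexed schedule. One plausible variant (an assumption; the original helper is not shown):

def adjust_learning_rate(optimizer, epoch, decay_rate=0.8, decay_every=5):
    # Shrink the learning rate by decay_rate once every decay_every epochs.
    if epoch % decay_every == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * decay_rate
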
Example #20
def main():
    global checkpoint, start_epoch, fine_tune_encoder, word_map_structure, word_map_cell, epochs_since_improvement, hyper_loss, id2word_stucture, id2word_cell, teds, best_TED

    if checkpoint is None:
        decoder_structure = DecoderStuctureWithAttention(
            attention_dim=attention_dim,
            embed_dim=emb_dim_structure,
            decoder_dim=decoder_dim_structure,
            vocab=word_map_structure,
            dropout=dropout)
        decoder_cell = DecoderCellPerImageWithAttention(
            attention_dim=attention_dim,
            embed_dim=emb_dim_cell,
            decoder_dim=decoder_dim_cell,
            vocab_size=len(word_map_cell),
            dropout=0.2,
            decoder_structure_dim=decoder_dim_structure)
        decoder_structure_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder_structure.parameters()),
                                                       lr=decoder_lr)
        decoder_cell_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder_cell.parameters()),
                                                  lr=decoder_lr)

        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        decoder_structure = checkpoint['decoder_structure']
        decoder_structure_optimizer = checkpoint["decoder_structure_optimizer"]

        decoder_cell = checkpoint["decoder_cell"]
        decoder_cell_optimizer = checkpoint["decoder_cell_optimizer"]

        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        best_TED = checkpoint['ted_score']

        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder_structure = decoder_structure.to(device)
    decoder_cell = decoder_cell.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    print("loading train_loader and val_loader:")
    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder, 'train', transform=transforms.Compose([normalize])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder, 'val', transform=transforms.Compose([normalize])),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)
    print("Done train_loader and val_loader:")
    # train foreach epoch
    for epoch in range(start_epoch, epochs):
        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_structure_optimizer, 0.8)
            adjust_learning_rate(decoder_cell_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)
        print("Starting train..............")
        train(train_loader=train_loader,
              encoder=encoder,
              decoder_structure=decoder_structure,
              decoder_cell=decoder_cell,
              criterion_structure=criterion,
              criterion_cell=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_structure_optimizer=decoder_structure_optimizer,
              decoder_cell_optimizer=decoder_cell_optimizer,
              epoch=epoch)
        print("Starting validation..............")
        recent_ted_score = val(val_loader=val_loader,
                               encoder=encoder,
                               decoder_structure=decoder_structure,
                               decoder_cell=decoder_cell,
                               criterion_structure=criterion,
                               criterion_cell=criterion)

        # Check if there was an improvement
        is_best = recent_ted_score > best_TED
        best_TED = max(recent_ted_score, best_TED)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, encoder,
                        decoder_structure, decoder_cell, encoder_optimizer,
                        decoder_structure_optimizer, decoder_cell_optimizer,
                        recent_ted_score, is_best)
def main():
    """
    Training and validation.
    """

    global checkpoint, start_epoch, fine_tune_encoder

    # Initialize / load checkpoint
    if checkpoint is None:

        encoder = Encoder()
        print(encoder)
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None

        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=vocab_size,
                                       encoder_dim=encoder_dim,
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # customized dataloader
    MyDataset = DualLoadDatasets(imgsz, txt_folder, img_folder, bin_folder,
                                 split, Gfiltersz, Gblursigma)
    # drop the last incomplete batch, since the dataset size may not be divisible by batch_size
    train_loader = torch.utils.data.DataLoader(MyDataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True,
                                               drop_last=True)

    #    val_loader = torch.utils.data.DataLoader(
    #        CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
    #        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)

    # Save checkpoint
    epoch = 0
    save_checkpoint(epoch, encoder, decoder, encoder_optimizer,
                    decoder_optimizer)
    print('saving models to models/checkpoint')

    # Epochs
    for epoch in range(start_epoch, epochs):
        #print(image_transforms)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              transform=transform,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # Save checkpoint
        save_checkpoint(epoch, encoder, decoder, encoder_optimizer,
                        decoder_optimizer)
        print('saving models to models/checkpoint')
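
This variant's save_checkpoint takes only the epoch, the two models, and their optimizers, and the surrounding prints suggest a fixed 'models/checkpoint' path. A minimal sketch consistent with that call (the exact dictionary layout is an assumption):

import torch

def save_checkpoint(epoch, encoder, decoder, encoder_optimizer, decoder_optimizer):
    # Assumed layout: bundle everything needed to resume training into one file.
    state = {
        'epoch': epoch,
        'encoder': encoder,
        'decoder': decoder,
        'encoder_optimizer': encoder_optimizer,
        'decoder_optimizer': decoder_optimizer,
    }
    torch.save(state, 'models/checkpoint')  # path taken from the print above
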
Example #22
def main():
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize / load checkpoint
    if checkpoint is None:

        emb_dim = 100  # override emb_dim to match the pretrained embeddings; remove if not using a pretrained model
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        pretrained_embeddings = decoder.create_pretrained_embedding_matrix(word_map)
        decoder.load_pretrained_embeddings(
            pretrained_embeddings)  # pretrained_embeddings should be of dimensions (len(word_map), emb_dim)
        decoder.fine_tune_embeddings(True)

        decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                             lr=encoder_lr) if fine_tune_encoder else None

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4, val_loss_avg, val_accu_avg = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)
        #write to tensorboard
        writer.add_scalar('validation_loss', val_loss_avg, epoch)
        writer.add_scalar('validation_accuracy', val_accu_avg, epoch)
        writer.add_scalar('validation_bleu4', recent_bleu4, epoch)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        print("Saving model to file",ckpt_name.format(epoch, bleu=recent_bleu4, loss=val_loss_avg, acc=val_accu_avg))
        save_checkpoint(ckpt_name.format(epoch, bleu=recent_bleu4, loss=val_loss_avg, acc=val_accu_avg), 
                        epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
                        decoder_optimizer, recent_bleu4, is_best)

    #close tensorboard writer
    writer.close()
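
The example above swaps the decoder's randomly initialized embeddings for a pretrained matrix of shape (len(word_map), emb_dim). The two decoder methods it calls are not shown; a plausible sketch, assuming the decoder stores its embedding table in self.embedding as a standard nn.Embedding:

import torch.nn as nn

# Sketch of the two DecoderWithAttention helpers assumed above.
def load_pretrained_embeddings(self, embeddings):
    # embeddings: FloatTensor of shape (len(word_map), emb_dim), e.g. stacked GloVe vectors
    self.embedding.weight = nn.Parameter(embeddings)

def fine_tune_embeddings(self, fine_tune=True):
    # enable or disable gradient updates for the embedding table
    for p in self.embedding.parameters():
        p.requires_grad = fine_tune
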
Example #23
def main():
    """
    Training and validation.
    """

    global best_bleu4, use_amp, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    #use_amp = True
    #print("Using amp for mized precision training")

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize / load checkpoint
    if checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                             lr=encoder_lr) if fine_tune_encoder else None

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    # use mixed precision training using Nvidia Apex
    if use_amp:
        decoder, decoder_optimizer = amp.initialize(
                                    decoder, decoder_optimizer, opt_level="O2",
                                    keep_batchnorm_fp32=True, loss_scale="dynamic")
    encoder = encoder.to(device)
    if not encoder_optimizer:
        print("Encoder is not being optimized")
    elif use_amp:
        encoder, encoder_optimizer = amp.initialize(
                                    encoder, encoder_optimizer, opt_level="O2",
                                    keep_batchnorm_fp32=True, loss_scale="dynamic")

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
                        decoder_optimizer, recent_bleu4, is_best)
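
When Apex amp is initialized at opt_level O2 as above, the matching train() has to route the backward pass through amp.scale_loss rather than calling loss.backward() directly. A sketch of that inner step, with loss assumed to be the per-batch loss and only the decoder optimizer scaled for brevity:

from apex import amp

# inside the batch loop of train(), after `loss` has been computed
decoder_optimizer.zero_grad()
if encoder_optimizer is not None:
    encoder_optimizer.zero_grad()
if use_amp:
    with amp.scale_loss(loss, decoder_optimizer) as scaled_loss:
        scaled_loss.backward()
else:
    loss.backward()
decoder_optimizer.step()
if encoder_optimizer is not None:
    encoder_optimizer.step()
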
def main():
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize / load checkpoint
    decoder = Fine_Tune_DecoderWithAttention(attention_dim=attention_dim,
                                             embed_dim=emb_dim,
                                             decoder_dim=decoder_dim,
                                             vocab_size=len(word_map),
                                             dropout=dropout)

    val_decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)

    decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                         lr=decoder_lr)
    encoder = Encoder()
    encoder.fine_tune(fine_tune_encoder)
    encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                         lr=encoder_lr) if fine_tune_encoder else None
    g_remover = RemoveGenderRegion()

    if checkpoint is not None:
        if is_cpu:
            checkpoint = torch.load(checkpoint,  map_location='cpu')
        else:
            checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']

        decoder = load_parameter(checkpoint['decoder'], decoder)
        encoder = load_parameter(checkpoint['encoder'], encoder)

        # decoder_optimizer = checkpoint['decoder_optimizer']
        decoder_optimizer = load_parameter(checkpoint['decoder_optimizer'], decoder_optimizer)
        #encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder_optimizer = load_parameter(checkpoint['encoder_optimizer'], encoder_optimizer)

        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)
        if freeze_decoder_lstm:
            decoder.freeze_LSTM(freeze=True)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)
    g_remover = g_remover.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # fix CUDA bug
    if not is_cpu:
        for state in decoder_optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

    '''
    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    '''

    if not supervised_training:
        train_loader = torch.utils.data.DataLoader(
            Fine_Tune_CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
            batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True)
    else:
        train_loader = torch.utils.data.DataLoader(
            Fine_Tune_CaptionDataset_With_Mask(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
            batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        if not supervised_training:
            # One epoch's training
            self_guided_fine_tune_train(train_loader=train_loader,
                                        encoder=encoder,
                                        decoder=decoder,
                                        criterion=criterion,
                                        encoder_optimizer=encoder_optimizer,
                                        decoder_optimizer=decoder_optimizer,
                                        g_remover=g_remover,
                                        epoch=epoch)
        else:
            supervised_guided_fine_tune_train(train_loader=train_loader,
                                              encoder=encoder,
                                              decoder=decoder,
                                              criterion=criterion,
                                              encoder_optimizer=encoder_optimizer,
                                              decoder_optimizer=decoder_optimizer,
                                              g_remover=g_remover,
                                              epoch=epoch)

        # One epoch's validation

        val_decoder = load_parameter(decoder, val_decoder)
        val_decoder = val_decoder.to(device)

        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=val_decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
                        decoder_optimizer, recent_bleu4, is_best, checkpoint_savepath) 
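
All of these scripts call encoder.fine_tune(...) to decide which backbone layers receive gradients. The method lives in the Encoder class, not shown here; a typical sketch, assuming the encoder wraps a torchvision ResNet in self.resnet and only unfreezes the later convolutional blocks:

def fine_tune(self, fine_tune=True):
    """Freeze the whole backbone, then optionally unfreeze the later conv blocks (split index 5 is an assumption)."""
    for p in self.resnet.parameters():
        p.requires_grad = False
    for child in list(self.resnet.children())[5:]:
        for p in child.parameters():
            p.requires_grad = fine_tune
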
Example #25
def main():
    """
    Training and validation.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_folder",
        default='data/',
        type=str,
        help="folder with data files saved by create_input_files.py")
    parser.add_argument("--data_name",
                        default='coco_5_cap_per_img_5_min_word_freq',
                        type=str,
                        help="base name shared by data files")
    parser.add_argument("--output_dir",
                        default='saved_models/',
                        type=str,
                        help="path to save checkpoints")
    parser.add_argument("--checkpoint",
                        default=None,
                        type=str,
                        help="path to checkpoint")
    parser.add_argument("--emb_dim",
                        default=512,
                        type=int,
                        help="dimension of word embeddings")
    parser.add_argument("--attention_dim",
                        default=512,
                        type=int,
                        help="dimension of attention linear layers")
    parser.add_argument("--decoder_dim",
                        default=512,
                        type=int,
                        help="dimension of decoder RNN")
    parser.add_argument("--dropout",
                        default=0.5,
                        type=float,
                        help="dimension of word embeddings")
    parser.add_argument("--start_epoch", default=0, type=int)
    parser.add_argument(
        "--epochs",
        default=120,
        type=int,
        help=
        "number of epochs to train for (if early stopping is not triggered)")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for training and testing")
    parser.add_argument("--workers",
                        default=8,
                        type=int,
                        help="num of workers for data-loading")
    parser.add_argument("--encoder_lr", default=1e-4, type=float)
    parser.add_argument("--decoder_lr", default=5e-4, type=float)
    parser.add_argument("--grad_clip",
                        default=5,
                        type=float,
                        help="clip gradients at an absolute value of")
    parser.add_argument(
        "--alpha_c",
        default=1,
        type=int,
        help=
        "regularization parameter for 'doubly stochastic attention', as in the paper"
    )
    parser.add_argument(
        "--print_freq",
        default=100,
        type=int,
        help="print training/validation stats every __ batches")
    parser.add_argument("--fine_tune_encoder",
                        action='store_true',
                        help="Whether to finetune the encoder")
    args = parser.parse_args()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.device = device

    best_bleu4 = 0

    epochs_since_improvement = 0

    # Read word map
    word_map_file = os.path.join(args.data_folder,
                                 'WORDMAP_' + args.data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize / load checkpoint
    if args.checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=args.attention_dim,
                                       embed_dim=args.emb_dim,
                                       decoder_dim=args.decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=args.dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=args.decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(args.fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=args.encoder_lr) if args.fine_tune_encoder else None

    else:
        checkpoint = torch.load(args.checkpoint)
        args.start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if args.fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(args.fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=args.encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(args.device)
    encoder = encoder.to(args.device)

    # Loss function
    criterion = nn.CrossEntropyLoss(ignore_index=0).to(args.device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        args.data_folder,
        args.data_name,
        'TRAIN',
        transform=transforms.Compose([normalize])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    print(f'train loader length: {len(train_loader)} batches')
    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        args.data_folder,
        args.data_name,
        'VAL',
        transform=transforms.Compose([normalize])),
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.workers,
                                             pin_memory=True)
    print(f'val loader length: {len(val_loader)} batches')

    # Epochs
    for epoch in range(args.start_epoch, args.epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if args.fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch,
              args=args)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion,
                                word_map=word_map,
                                args=args)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(args.data_name, args.output_dir, epoch,
                        epochs_since_improvement, encoder, decoder,
                        encoder_optimizer, decoder_optimizer, recent_bleu4,
                        is_best)
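
The parser above exposes grad_clip and alpha_c, but both are consumed inside train(), which is not reproduced here. Two short sketches of how they are conventionally applied in this kind of training loop (alphas, the per-timestep attention weights returned by the decoder, is an assumption):

def clip_gradient(optimizer, grad_clip):
    # clamp every gradient to [-grad_clip, grad_clip] before the optimizer step
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)

# 'doubly stochastic attention' regularization from the Show, Attend and Tell paper,
# typically added to the cross-entropy loss inside train():
#     loss = loss + args.alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
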
def main():
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map
    best_bleu4 = config.best_bleu4
    epochs_since_improvement = config.epochs_since_improvement
    checkpoint = config.checkpoint
    start_epoch = config.start_epoch
    fine_tune_encoder = config.fine_tune_encoder
    data_name = config.data_name

    log_f = open(config.train_log_path, 'a+', encoding='utf-8')

    # Read word map
    word_map_file = os.path.join(config.data_folder,
                                 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize / load checkpoint
    if checkpoint is None:
        print('no checkpoint, rebuild')
        log_f.write('\n\nno checkpoint, rebuild' + '\n')
        decoder = DecoderWithAttention(attention_dim=config.attention_dim,
                                       embed_dim=config.emb_dim,
                                       decoder_dim=config.decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=config.dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=config.decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=config.encoder_lr) if fine_tune_encoder else None

    else:
        print('checkpoint exists, continuing.. \n{}'.format(checkpoint))
        log_f.write('\n\ncheckpoint exists, continuing.. \n{}'.format(checkpoint) +
                    '\n')
        log_f.close()
        checkpoint = torch.load(
            checkpoint,
            map_location=config.device)  # map_location=device for cpu
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:  # check for fine tuning
            encoder.fine_tune(
                fine_tune_encoder)  # change requires_grad for weights
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=config.encoder_lr)

    # Move to GPU, if available
    # decoder = decoder.to(config.device)   # no GPU
    # encoder = encoder.to(config.device)   # no GPU

    # Loss function
    # criterion = nn.CrossEntropyLoss().to(config.device)   # no GPU
    criterion = nn.CrossEntropyLoss()

    # Custom batch dataloaders
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # mean and std of the original ResNet (ImageNet) preprocessing
        std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(config.data_folder,
                       data_name,
                       'TRAIN',
                       transform=transforms.Compose(
                           [normalize])),  # CaptionDataset is in datasets.py
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.workers,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        config.data_folder,
        data_name,
        'VAL',
        transform=transforms.Compose([normalize])),
                                             batch_size=config.batch_size,
                                             shuffle=True,
                                             num_workers=config.workers,
                                             pin_memory=True)

    # Epochs
    val_writer = SummaryWriter(
        log_dir=config.tensorboard_path + '/val/' +
        time.strftime('%m-%d_%H%M', time.localtime()))  # for tensorboard
    for epoch in range(start_epoch, config.epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)  # utils.py
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion,
                                writer=val_writer,
                                epoch=epoch)

        # Check if there was an improvement, check each epoch
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        log_f = open(config.train_log_path, 'a+', encoding='utf-8')
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
            log_f.write("\nEpochs since last improvement: %d\n" %
                        (epochs_since_improvement, ) + '\n')
        else:
            epochs_since_improvement = 0
            log_f.write('\n')
        log_f.close()
        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)

    val_writer.close()
def main(checkpoint, tienet):
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, start_epoch, fine_tune_encoder, data_name, word_map

    if checkpoint:
        dest_dir = checkpoint
        checkpoint = os.path.join(
            dest_dir,
            'checkpoint_mimiccxr_1_cap_per_img_5_min_word_freq.pth.tar'
        )  # path to checkpoint, None if none
    else:
        dest_dir = os.path.join(
            '/data/medg/misc/liuguanx/TieNet/models',
            datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S-%f'))
        os.makedirs(dest_dir)
        checkpoint = None
    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize / load checkpoint
    if checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None
        if tienet:
            jointlearner = JointLearning(num_global_att=num_global_att,
                                         s=s,
                                         decoder_dim=decoder_dim,
                                         label_size=label_size)
            jointlearner_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, jointlearner.parameters()),
                                                      lr=jointlearning_lr)
        else:
            jointlearner = None
            jointlearner_optimizer = None

    else:
        checkpoint = torch.load(checkpoint)
        print('checkpoint loaded')
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['best_bleu']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        jointlearner = checkpoint['jointlearner']
        jointlearner_optimizer = checkpoint['jointlearner_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    if torch.cuda.device_count() > 1:
        print('Using', torch.cuda.device_count(), 'GPUs')
        # decoder = nn.DataParallel(decoder)
        encoder = nn.DataParallel(encoder, device_ids=[1])
        if tienet:
            jointlearner = nn.DataParallel(jointlearner, device_ids=[1])
    decoder = decoder.to(device)
    encoder = encoder.to(device)
    if tienet:
        jointlearner = jointlearner.to(device)

    # Loss function
    criterion_R = nn.CrossEntropyLoss().to(device)
    criterion_C = nn.BCEWithLogitsLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transform=transforms.Compose([normalize])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'VAL',
        transform=transforms.Compose([normalize])),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if tienet:
                adjust_learning_rate(jointlearner_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              jointlearner=jointlearner,
              criterion_R=criterion_R,
              criterion_C=criterion_C,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              jointlearner_optimizer=jointlearner_optimizer,
              epoch=epoch,
              dest_dir=dest_dir,
              tienet=tienet)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                jointlearner=jointlearner,
                                criterion_R=criterion_R,
                                criterion_C=criterion_C,
                                tienet=tienet)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, jointlearner, encoder_optimizer,
                        decoder_optimizer, jointlearner_optimizer,
                        recent_bleu4, best_bleu4, is_best, dest_dir)
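
The TieNet variant trains against two criteria: criterion_R (cross-entropy over the generated report) and criterion_C (BCE-with-logits over the image labels). The combination happens inside train(); a plausible sketch of that step, where scores/targets/label_logits/labels are assumed intermediate names and equal weighting is an assumption:

# inside the batch loop of train(), after the decoder and joint learner have run
loss_R = criterion_R(scores, targets)        # report-generation loss
loss_C = criterion_C(label_logits, labels)   # multi-label classification loss
loss = loss_R + loss_C                       # assumed equal weighting of the two objectives
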
Example #28
def fit(t_params, checkpoint=None, m_params=None, logger=None):

    # info
    data_name = t_params['data_name']
    imgs_path = t_params['imgs_path']
    df_path = t_params['df_path']
    vocab = t_params['vocab']

    start_epoch = 0
    epochs_since_improvement = 0
    best_bleu4 = 0
    epochs = t_params['epochs']
    batch_size = t_params['batch_size']
    workers = t_params['workers']
    encoder_lr = t_params['encoder_lr']
    decoder_lr = t_params['decoder_lr']
    fine_tune_encoder = t_params['fine_tune_encoder']

    # pretrained word embeddings
    pretrained_embeddings = t_params['pretrained_embeddings']
    if pretrained_embeddings:
        fine_tune_embeddings = t_params['fine_tune_embeddings']
        embeddings_matrix = m_params['embeddings_matrix']

    # init / load checkpoint
    if checkpoint is None:

        # getting hyperparameters
        attention_dim = m_params['attention_dim']
        embed_dim = m_params['embed_dim']
        decoder_dim = m_params['decoder_dim']
        encoder_dim = m_params['encoder_dim']
        dropout = m_params['dropout']

        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=embed_dim,
                                       decoder_dim=decoder_dim,
                                       encoder_dim=encoder_dim,
                                       vocab_size=len(vocab),
                                       dropout=dropout)

        if pretrained_embeddings:
            decoder.load_pretrained_embeddings(
                torch.tensor(embeddings_matrix, dtype=torch.float32))
            decoder.fine_tune_embeddings(fine_tune=fine_tune_embeddings)

        decoder_optimizer = torch.optim.RMSprop(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                                lr=decoder_lr)

        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.RMSprop(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None

    # load checkpoint
    else:
        checkpoint = torch.load(checkpoint)
        print('Loaded Checkpoint!!')
        start_epoch = checkpoint['epoch'] + 1
        print(f"Starting Epoch: {start_epoch}")
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.RMSprop(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                    lr=encoder_lr)

    # Schedulers
    decoder_scheduler = ReduceLROnPlateau(decoder_optimizer,
                                          patience=2,
                                          verbose=True)
    if fine_tune_encoder:
        encoder_scheduler = ReduceLROnPlateau(encoder_optimizer,
                                              patience=2,
                                              verbose=True)

    # move to gpu, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # dataloaders
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    print('Loading Data')
    train_loader, val_loader = get_loaders(batch_size, imgs_path, df_path,
                                           transform, vocab, False, workers)
    print('_' * 50)

    print('-' * 20, 'Fitting', '-' * 20)
    for epoch in range(start_epoch, epochs):

        # if epochs_since_improvement > 0 and epochs_since_improvement % 2 == 0:
        #     adjust_learning_rate(decoder_optimizer, 0.8)
        #     if fine_tune_encoder:
        #         adjust_learning_rate(encoder_optimizer, 0.8)

        print('_' * 50)
        print('-' * 20, 'Training', '-' * 20)
        # one epoch of training
        epoch_time = AverageMeter()
        start_time = time.time()
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch,
              logger=logger)
        epoch_time.update(time.time() - start_time)
        print(f"Epoch train time {epoch_time.val:.3f} (epoch_time.avg:.3f)")

        # one epoch of validation
        epoch_time = AverageMeter()
        start_time = time.time()
        print('-' * 20, 'Validation', '-' * 20)
        b1, b2, b3, recent_bleu4 = validate(val_loader=val_loader,
                                            encoder=encoder,
                                            decoder=decoder,
                                            criterion=criterion,
                                            vocab=vocab,
                                            epoch=epoch,
                                            logger=logger)
        epoch_time.update(time.time() - start_time)
        # tensorboard
        logger.add_scalar('b-1/valid', b1, epoch)
        logger.add_scalar('b-2/valid', b2, epoch)
        logger.add_scalar('b-3/valid', b3, epoch)
        logger.add_scalar('b-4/valid', recent_bleu4, epoch)
        # logger.add_scalar('Meteor/valid', m, epoch)
        print(
            f"Epoch validation time {epoch_time.val:.3f} (avg {epoch_time.avg:.3f})")

        # check for improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print(
                f'\nEpochs since last improvement: {epochs_since_improvement}'
            )
        else:
            # reset
            epochs_since_improvement = 0

        # stop training if no improvement for 5 epochs
        if epochs_since_improvement == 5:
            print('No improvement for 5 consecutive epochs, terminating...')
            break

        # learning rate schedular
        decoder_scheduler.step(recent_bleu4)
        if fine_tune_encoder:
            encoder_scheduler.step(recent_bleu4)

        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)
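
fit() times each epoch with an AverageMeter, which is not defined in this excerpt. A minimal sketch of the usual running-average helper assumed here:

class AverageMeter:
    """Tracks the latest value, running sum, count, and average of a scalar."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
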
Example #29
encoder_lr = 5e-4  # learning rate for encoder if fine-tuning
decoder_lr = 5e-4  # learning rate for decoder
grad_clip = 5.  # clip gradients at an absolute value of
alpha_c = 1.  # regularization parameter for 'doubly stochastic attention', as in the paper
fine_tune_encoder = True  # fine-tune encoder?
decoder = DecoderWithAttention(attention_dim=attention_dim,
                               embed_dim=emb_dim,
                               decoder_dim=decoder_dim,
                               vocab_size=len(vocab),
                               dropout=dropout)
decoder.fine_tune_embeddings(True)
decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                                   decoder.parameters()),
                                     lr=decoder_lr)
encoder = Encoder()
encoder.fine_tune(fine_tune_encoder)
encoder_optimizer = torch.optim.Adam(
    params=filter(lambda p: p.requires_grad, encoder.parameters()),
    lr=encoder_lr) if fine_tune_encoder else None

decoder_sched = torch.optim.lr_scheduler.CosineAnnealingLR(decoder_optimizer,
                                                           5,
                                                           eta_min=1e-5,
                                                           last_epoch=-1)
encoder_sched = torch.optim.lr_scheduler.CosineAnnealingLR(encoder_optimizer,
                                                           5,
                                                           eta_min=1e-5,
                                                           last_epoch=-1)
encoder = encoder.cuda()
decoder = decoder.cuda()
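
This fragment builds the cosine-annealing schedulers but never advances them; in a full loop they would normally be stepped once per epoch after training, roughly as in the sketch below (the train(...) call and epochs count are assumptions about the surrounding script):

for epoch in range(epochs):
    train(train_loader=train_loader, encoder=encoder, decoder=decoder,
          criterion=criterion, encoder_optimizer=encoder_optimizer,
          decoder_optimizer=decoder_optimizer, epoch=epoch)
    decoder_sched.step()  # advance the cosine schedule once per epoch
    encoder_sched.step()
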
Example #30
def main():
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize / load checkpoint
    if checkpoint is None:
        # resnet
        encoder = Encoder(model_name="resnet")
        encoder_dim = 2048

        # squeezenet
        # encoder = Encoder(model_name="squeezenet")
        # encoder_dim = 1000

        # vgg
        # encoder = Encoder(model_name="vgg")
        # encoder_dim = 512

        # # mobileNet
        # encoder = Encoder(model_name="mobileNet")
        # encoder_dim = 1024

        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None

        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout,
                                       encoder_dim=encoder_dim)

        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transform=transforms.Compose([normalize])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'VAL',
        transform=transforms.Compose([normalize])),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 2 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)
        # stop this run once BLEU-4 reaches 23, so that encoder fine-tuning can start from this checkpoint
        if best_bleu4 >= 23:
            break

        count_parameters(encoder)
        count_parameters(decoder)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)
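
The last variant calls count_parameters on the encoder and decoder before every epoch. The helper is not included in this excerpt; a small sketch of what it presumably does (printing as well as returning the count is an assumption):

def count_parameters(model):
    # total number of trainable parameters in the model
    n = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'{model.__class__.__name__}: {n:,} trainable parameters')
    return n
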