Ejemplo n.º 1
0
def GetResnet101Features(
        data_folder='C:/Users/paoca/Documents/UVA PHD/NLP/PROJECT/UnnecesaryDataFolder',
        data_name='coco_5_cap_per_img_5_min_word_freq'):
    """
    Dump ResNet-101 encoder features and VAE-encoder outputs for the TRAIN split.

    Iterates ``CaptionDataset(data_folder, data_name, 'TRAIN')`` in order
    (``shuffle=False``, batch size 5) under ``torch.no_grad()``. For batch
    ``i`` it pickles ``encoder(imgs)[0]`` to
    ``<data_folder>/TrainResnet101Features/<i>.p`` and
    ``encoderVae_encoder(imgs, caps, caplens)[0]`` to
    ``<data_folder>/TrainResnet101Features/VAE_<i>.p``. Only index 0 of each
    output is written. NOTE(review): with 5 captions per image and a batch
    size of 5 this presumably keeps one entry per image; confirm against
    CaptionDataset's sample ordering.

    :param data_folder: folder with data files saved by create_input_files.py;
        also the parent of the ``TrainResnet101Features`` output directory
        (created if missing).
    :param data_name: base name shared by the processed data files.
    """

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    feature_dir = os.path.join(data_folder, 'TrainResnet101Features')
    os.makedirs(feature_dir, exist_ok=True)

    def _dump(obj, path):
        # Close every file deterministically (the original leaked handles).
        with open(path, 'wb') as f:
            pickle.dump(obj, f)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transform=transforms.Compose([normalize])),
                                               batch_size=5,
                                               shuffle=False,
                                               pin_memory=True)

    with torch.no_grad():
        encoder = Encoder()
        encoder.fine_tune(False)

        emb_dim = 512
        decoder_dim = 512
        encoderVae_encoder = EncodeVAE_Encoder(embed_dim=emb_dim,
                                               decoder_dim=decoder_dim,
                                               vocab_size=len(word_map))
        encoderVae_encoder.fine_tune(False)

        encoder.eval()
        encoderVae_encoder.eval()

        encoder = encoder.to(device)
        encoderVae_encoder = encoderVae_encoder.to(device)

        for i, (imgs, caps, caplens) in enumerate(train_loader):
            if i % 100 == 0:
                print(i)

            imgs = imgs.to(device)
            caps = caps.to(device)
            caplens = caplens.to(device)

            res = encoder(imgs)
            h = encoderVae_encoder(imgs, caps, caplens)

            _dump(res[0].cpu().numpy(), os.path.join(feature_dir, str(i) + ".p"))
            _dump(h[0].cpu().numpy(), os.path.join(feature_dir, "VAE_" + str(i) + ".p"))
def main():
    """
    Training and validation (fine-tuning variant with gender-region removal).

    Builds a ``Fine_Tune_DecoderWithAttention`` that is trained and a plain
    ``DecoderWithAttention`` (``val_decoder``) that receives the trained
    weights via ``load_parameter`` before every validation pass. Optionally
    resumes weights, counters and optimizer state from ``checkpoint``. Then,
    per epoch, it runs either ``self_guided_fine_tune_train`` or
    ``supervised_guided_fine_tune_train`` (selected by
    ``supervised_training``), validates for BLEU-4 and saves a checkpoint.
    Learning rates decay by 0.8 after every 8 epochs without improvement, and
    training stops after 20 such epochs.

    Reads module-level configuration (``data_folder``, model sizes, learning
    rates, ``is_cpu``, ``freeze_decoder_lstm``, ``supervised_training``,
    ``batch_size``, ``workers``, ``epochs``, ``device``,
    ``checkpoint_savepath``) and updates the globals declared below.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize models: `decoder` is trained, `val_decoder` is only used for validation
    decoder = Fine_Tune_DecoderWithAttention(attention_dim=attention_dim,
                                             embed_dim=emb_dim,
                                             decoder_dim=decoder_dim,
                                             vocab_size=len(word_map),
                                             dropout=dropout)

    val_decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)

    decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                         lr=decoder_lr)
    encoder = Encoder()
    encoder.fine_tune(fine_tune_encoder)
    # An encoder optimizer exists only when the encoder is being fine-tuned.
    encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                         lr=encoder_lr) if fine_tune_encoder else None
    g_remover = RemoveGenderRegion()

    if checkpoint is not None:
        # Resume: restore counters, model weights and optimizer state.
        if is_cpu:
            checkpoint = torch.load(checkpoint, map_location='cpu')
        else:
            checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']

        decoder = load_parameter(checkpoint['decoder'], decoder)
        encoder = load_parameter(checkpoint['encoder'], encoder)
        decoder_optimizer = load_parameter(checkpoint['decoder_optimizer'], decoder_optimizer)

        # Restore the encoder optimizer only when one exists (fine_tune_encoder is
        # set) and the checkpoint stored one. The previous guard tested
        # `encoder_optimizer is None`, which is never true here because the
        # optimizer is always created above when fine-tuning, so the state was
        # silently never restored.
        if encoder_optimizer is not None and checkpoint.get('encoder_optimizer') is not None:
            encoder_optimizer = load_parameter(checkpoint['encoder_optimizer'], encoder_optimizer)

        if freeze_decoder_lstm:
            decoder.freeze_LSTM(freeze=True)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)
    g_remover = g_remover.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # fix CUDA bug: restored optimizer state tensors may live on the CPU;
    # move them to the GPU for every optimizer that exists.
    if not is_cpu:
        for optimizer in (decoder_optimizer, encoder_optimizer):
            if optimizer is None:
                continue
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()

    if not supervised_training:
        train_dataset_cls = Fine_Tune_CaptionDataset
        train_fn = self_guided_fine_tune_train
    else:
        train_dataset_cls = Fine_Tune_CaptionDataset_With_Mask
        train_fn = supervised_guided_fine_tune_train

    train_loader = torch.utils.data.DataLoader(
        train_dataset_cls(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train_fn(train_loader=train_loader,
                 encoder=encoder,
                 decoder=decoder,
                 criterion=criterion,
                 encoder_optimizer=encoder_optimizer,
                 decoder_optimizer=decoder_optimizer,
                 g_remover=g_remover,
                 epoch=epoch)

        # One epoch's validation, using a plain decoder carrying the trained weights
        val_decoder = load_parameter(decoder, val_decoder)
        val_decoder = val_decoder.to(device)

        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=val_decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
                        decoder_optimizer, recent_bleu4, is_best, checkpoint_savepath)
Ejemplo n.º 3
0
def main():
    """
    Training and validation with pretrained word embeddings and TensorBoard logging.

    When ``checkpoint`` is None, builds a fresh ``DecoderWithAttention`` whose
    embedding layer is initialised from ``create_pretrained_embedding_matrix``
    (and left trainable), plus an ``Encoder``. Otherwise the models, optimizers
    and counters are taken from the checkpoint. Each epoch trains, validates
    (BLEU-4, loss, accuracy), logs the three validation metrics to ``writer``
    and saves a checkpoint whose file name is built from ``ckpt_name``.
    Learning rates decay by 0.8 after every 8 epochs without improvement, and
    training stops after 20 such epochs. ``writer`` is closed on exit, also
    when an exception propagates.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize / load checkpoint
    if checkpoint is None:

        emb_dim = 100  # size of the pretrained embeddings; remove if not using a pretrained model
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        pretrained_embeddings = decoder.create_pretrained_embedding_matrix(word_map)
        decoder.load_pretrained_embeddings(
            pretrained_embeddings)  # pretrained_embeddings should be of dimensions (len(word_map), emb_dim)
        decoder.fine_tune_embeddings(True)

        decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                             lr=encoder_lr) if fine_tune_encoder else None

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        # The checkpoint may come from a run without encoder fine-tuning.
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)

    try:
        # Epochs
        for epoch in range(start_epoch, epochs):

            # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
            if epochs_since_improvement == 20:
                break
            if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
                adjust_learning_rate(decoder_optimizer, 0.8)
                if fine_tune_encoder:
                    adjust_learning_rate(encoder_optimizer, 0.8)

            # One epoch's training
            train(train_loader=train_loader,
                  encoder=encoder,
                  decoder=decoder,
                  criterion=criterion,
                  encoder_optimizer=encoder_optimizer,
                  decoder_optimizer=decoder_optimizer,
                  epoch=epoch)

            # One epoch's validation
            recent_bleu4, val_loss_avg, val_accu_avg = validate(val_loader=val_loader,
                                                                encoder=encoder,
                                                                decoder=decoder,
                                                                criterion=criterion)
            # write to tensorboard
            writer.add_scalar('validation_loss', val_loss_avg, epoch)
            writer.add_scalar('validation_accuracy', val_accu_avg, epoch)
            writer.add_scalar('validation_bleu4', recent_bleu4, epoch)

            # Check if there was an improvement
            is_best = recent_bleu4 > best_bleu4
            best_bleu4 = max(recent_bleu4, best_bleu4)
            if not is_best:
                epochs_since_improvement += 1
                print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
            else:
                epochs_since_improvement = 0

            # Save checkpoint (file name formatted once and reused)
            ckpt_file = ckpt_name.format(epoch, bleu=recent_bleu4, loss=val_loss_avg, acc=val_accu_avg)
            print("Saving model to file", ckpt_file)
            save_checkpoint(ckpt_file,
                            epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
                            decoder_optimizer, recent_bleu4, is_best)
    finally:
        # close tensorboard writer, even if training/validation raised
        writer.close()
Ejemplo n.º 4
0
def main():
    """
    Training and validation for the structure + cell decoder model.

    Builds (or restores from ``checkpoint``) a shared image ``Encoder``, a
    structure decoder and a per-cell decoder, each with its own Adam
    optimizer. Then, per epoch, it trains, scores the validation set with
    ``val`` (higher is better) and saves a checkpoint. Learning rates decay by
    0.8 after every 8 epochs without improvement, and training stops after 20
    such epochs.
    """
    global checkpoint, start_epoch, fine_tune_encoder, word_map_structure, word_map_cell, epochs_since_improvement, hyper_loss, id2word_stucture, id2word_cell, teds, best_TED

    if checkpoint is None:
        decoder_structure = DecoderStuctureWithAttention(
            attention_dim=attention_dim,
            embed_dim=emb_dim_structure,
            decoder_dim=decoder_dim_structure,
            vocab=word_map_structure,
            dropout=dropout)
        decoder_cell = DecoderCellPerImageWithAttention(
            attention_dim=attention_dim,
            embed_dim=emb_dim_cell,
            decoder_dim=decoder_dim_cell,
            vocab_size=len(word_map_cell),
            dropout=0.2,
            decoder_structure_dim=decoder_dim_structure)
        decoder_structure_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder_structure.parameters()),
                                                       lr=decoder_lr)
        # Each decoder is optimized over its OWN parameters. This was previously
        # built over decoder_structure.parameters(), so the cell decoder was
        # never updated and the structure decoder was stepped twice.
        decoder_cell_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder_cell.parameters()),
                                                  lr=decoder_lr)

        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        decoder_structure = checkpoint['decoder_structure']
        decoder_structure_optimizer = checkpoint["decoder_structure_optimizer"]

        decoder_cell = checkpoint["decoder_cell"]
        decoder_cell_optimizer = checkpoint["decoder_cell_optimizer"]

        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        best_TED = checkpoint['ted_score']

        # The checkpoint may come from a run without encoder fine-tuning.
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder_structure = decoder_structure.to(device)
    decoder_cell = decoder_cell.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    print("loading train_loader and val_loader:")
    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder, 'train', transform=transforms.Compose([normalize])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder, 'val', transform=transforms.Compose([normalize])),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)
    print("Done train_loader and val_loader:")
    # train foreach epoch
    for epoch in range(start_epoch, epochs):
        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            # Decay the optimizers' learning rates (the decoder modules
            # themselves were being passed here before).
            adjust_learning_rate(decoder_structure_optimizer, 0.8)
            adjust_learning_rate(decoder_cell_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)
        print("Starting train..............")
        train(train_loader=train_loader,
              encoder=encoder,
              decoder_structure=decoder_structure,
              decoder_cell=decoder_cell,
              criterion_structure=criterion,
              criterion_cell=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_structure_optimizer=decoder_structure_optimizer,
              decoder_cell_optimizer=decoder_cell_optimizer,
              epoch=epoch)
        print("Starting validation..............")
        recent_ted_score = val(val_loader=val_loader,
                               encoder=encoder,
                               decoder_structure=decoder_structure,
                               decoder_cell=decoder_cell,
                               criterion_structure=criterion,
                               criterion_cell=criterion)

        # Check if there was an improvement
        is_best = recent_ted_score > best_TED
        best_TED = max(recent_ted_score, best_TED)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, encoder,
                        decoder_structure, decoder_cell, encoder_optimizer,
                        decoder_structure_optimizer, decoder_cell_optimizer,
                        recent_ted_score, is_best)
Ejemplo n.º 5
0
def main():
    """
    Train the role-aware captioning model (no validation pass is run yet).

    Loads the token and role vocabularies (``token2id.json`` and
    ``roles2id.json`` in ``data_folder``) into the globals ``word_map`` and
    ``role_map``. It then builds a fresh encoder/decoder pair with optimizers,
    or restores them from ``checkpoint``, moves both models to ``device`` and
    calls ``train`` once per epoch over the TRAIN split of ``FrameDataset``.
    """

    global epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map, role_map

    def _read_vocab(stem):
        # Vocabulary files live next to the data as <data_folder>/<stem>.json.
        with open(os.path.join(data_folder, stem + '.json'), 'r') as handle:
            return json.load(handle)

    word_map = _read_vocab('token2id')
    role_map = _read_vocab('roles2id')

    if checkpoint is not None:
        # Resume models, optimizers and counters from the saved checkpoint.
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        needs_encoder_optimizer = fine_tune_encoder is True and encoder_optimizer is None
        if needs_encoder_optimizer:
            encoder.fine_tune(fine_tune_encoder)
            trainable = filter(lambda p: p.requires_grad, encoder.parameters())
            encoder_optimizer = torch.optim.Adam(params=trainable, lr=encoder_lr)
    else:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       role_vocab_size=len(role_map),
                                       role_embed_dim=role_dim,
                                       dropout=dropout)
        decoder_params = filter(lambda p: p.requires_grad, decoder.parameters())
        decoder_optimizer = torch.optim.Adam(params=decoder_params, lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        if fine_tune_encoder:
            encoder_params = filter(lambda p: p.requires_grad, encoder.parameters())
            encoder_optimizer = torch.optim.Adam(params=encoder_params, lr=encoder_lr)
        else:
            encoder_optimizer = None

    # Place both models on the configured device.
    decoder = decoder.to(device)
    encoder = encoder.to(device)
    criterion = nn.CrossEntropyLoss().to(device)

    # ImageNet channel statistics used to normalise the input images.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = FrameDataset(data_folder, 'TRAIN',
                                 transform=transforms.Compose([normalize]))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    for epoch in range(start_epoch, epochs):
        # TODO: learning-rate decay and a real validation pass are not implemented.
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)
        print('start validation..')
Ejemplo n.º 6
0
class Dreamer(Agent):
    # The agent has its own replay buffer, update, act
    def __init__(self, args):
        """
        Build the Dreamer models, replay buffer and optimizers.

        Every hyper-parameter is read as an attribute of ``args`` (for example
        ``args.belief_size`` or ``args.device``). ``args`` is therefore an
        attribute-style namespace, not a plain dict.

        :param args: configuration object. It provides, among others, the
            model sizes (belief/state/action/hidden/embedding/observation),
            activations (``dense_act``, ``cnn_act``), learning rates
            (``world_lr``, ``actor_lr``, ``value_lr``), ``free_nats``, replay
            settings (``experience_size``, ``bit_depth``), ``device`` and the
            feature flags ``symbolic``, ``pcont``, ``auto_temp`` and
            ``fix_speed``.
        """
        super().__init__()
        self.args = args
        # Initialise model parameters randomly
        # NOTE(review): models are built in a fixed order; presumably this
        # matters for reproducible initialisation under a fixed seed.
        self.transition_model = TransitionModel(
            args.belief_size, args.state_size, args.action_size,
            args.hidden_size, args.embedding_size,
            args.dense_act).to(device=args.device)

        # Symbolic observations use the dense activation, image ones the CNN one.
        self.observation_model = ObservationModel(
            args.symbolic,
            args.observation_size,
            args.belief_size,
            args.state_size,
            args.embedding_size,
            activation_function=(args.dense_act if args.symbolic else
                                 args.cnn_act)).to(device=args.device)

        self.reward_model = RewardModel(args.belief_size, args.state_size,
                                        args.hidden_size,
                                        args.dense_act).to(device=args.device)

        self.encoder = Encoder(args.symbolic, args.observation_size,
                               args.embedding_size,
                               args.cnn_act).to(device=args.device)

        self.actor_model = ActorModel(
            args.action_size,
            args.belief_size,
            args.state_size,
            args.hidden_size,
            activation_function=args.dense_act,
            fix_speed=args.fix_speed,
            throttle_base=args.throttle_base).to(device=args.device)

        # Twin critics; the actor loss uses their elementwise minimum.
        self.value_model = ValueModel(args.belief_size, args.state_size,
                                      args.hidden_size,
                                      args.dense_act).to(device=args.device)

        self.value_model2 = ValueModel(args.belief_size, args.state_size,
                                       args.hidden_size,
                                       args.dense_act).to(device=args.device)

        # Continuation (non-terminal probability) model; it is only trained when args.pcont is set.
        self.pcont_model = PCONTModel(args.belief_size, args.state_size,
                                      args.hidden_size,
                                      args.dense_act).to(device=args.device)

        # Frozen copies of both critics (no gradients). They are not updated
        # here; presumably they are synced elsewhere in the class.
        self.target_value_model = deepcopy(self.value_model)
        self.target_value_model2 = deepcopy(self.value_model2)

        for p in self.target_value_model.parameters():
            p.requires_grad = False
        for p in self.target_value_model2.parameters():
            p.requires_grad = False

        # World-model parameters: transition, observation, reward and encoder,
        # plus the continuation model when args.pcont is enabled.
        self.world_param = list(self.transition_model.parameters())\
                          + list(self.observation_model.parameters())\
                          + list(self.reward_model.parameters())\
                          + list(self.encoder.parameters())
        if args.pcont:
            self.world_param += list(self.pcont_model.parameters())

        # setup optimizer: one for the world model, one for the actor and one shared by both critics
        self.world_optimizer = optim.Adam(self.world_param, lr=args.world_lr)
        self.actor_optimizer = optim.Adam(self.actor_model.parameters(),
                                          lr=args.actor_lr)
        self.value_optimizer = optim.Adam(list(self.value_model.parameters()) +
                                          list(self.value_model2.parameters()),
                                          lr=args.value_lr)

        # Free-nats threshold: per-step KL values below it are clamped up to it in the world loss
        self.free_nats = torch.full(
            (1, ), args.free_nats, dtype=torch.float32,
            device=args.device)  # Allowed deviation in KL divergence

        # TODO: change it to the new replay buffer, in buffer.py
        self.D = ExperienceReplay(args.experience_size, args.symbolic,
                                  args.observation_size, args.action_size,
                                  args.bit_depth, args.device)

        if self.args.auto_temp:
            # setup for learning of alpha term (temp of the entropy term)
            # log_temp is the learnable log-temperature. target_entropy is minus
            # the number of action dims (one fewer when fix_speed is set).
            self.log_temp = torch.zeros(1,
                                        requires_grad=True,
                                        device=args.device)
            self.target_entropy = -np.prod(
                args.action_size if not args.fix_speed else self.args.
                action_size - 1).item()  # heuristic value from SAC paper
            self.temp_optimizer = optim.Adam(
                [self.log_temp], lr=args.value_lr)  # use the same value_lr

        # TODO: print out the param used in Dreamer
        # var_counts = tuple(count_vars(module) for module in [self., self.ac.q1, self.ac.q2])
        # print('\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d\n' % var_counts)

    # def process_im(self, image, image_size=None, rgb=None):
    #   # Resize, put channel first, convert it to a tensor, centre it to [-0.5, 0.5] and add batch dimenstion.
    #
    #   def preprocess_observation_(observation, bit_depth):
    #     # Preprocesses an observation inplace (from float32 Tensor [0, 255] to [-0.5, 0.5])
    #     observation.div_(2 ** (8 - bit_depth)).floor_().div_(2 ** bit_depth).sub_(
    #       0.5)  # Quantise to given bit depth and centre
    #     observation.add_(torch.rand_like(observation).div_(
    #       2 ** bit_depth))  # Dequantise (to approx. match likelihood of PDF of continuous images vs. PMF of discrete images)
    #
    #   image = image[40:, :, :]  # clip the above 40 rows
    #   image = torch.tensor(cv2.resize(image, (40, 40), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1),
    #                         dtype=torch.float32)  # Resize and put channel first
    #
    #   preprocess_observation_(image, self.args.bit_depth)
    #   return image.unsqueeze(dim=0)
    def process_im(self, images, image_size=None, rgb=None):
        """Turn a raw image into a normalised grayscale observation tensor.

        The image is resized to 40x40, collapsed to one channel with the luma
        weights (0.299, 0.587, 0.114), scaled by 1/255 and shifted by -0.5,
        and then given two leading singleton dims (channel, batch).
        ``image_size`` and ``rgb`` are accepted but unused.

        NOTE(review): assumes a channel-last 3-channel image with 0-255 pixel
        values, which gives a result of shape [1, 1, 40, 40] in [-0.5, 0.5].
        """
        resized = cv2.resize(images, (40, 40))
        luma = np.dot(resized, [0.299, 0.587, 0.114])
        frame = torch.tensor(luma, dtype=torch.float32)
        frame = frame.div_(255.).sub_(0.5)
        with_channel = frame.unsqueeze(dim=0)
        return with_channel.unsqueeze(dim=0)

    def append_buffer(self, new_traj):
        """Push every step of a freshly collected trajectory into the replay buffer.

        ``new_traj`` is a sequence of ``(observation, action, reward, done)``
        tuples. Each action is moved to the CPU before storage. No data
        augmentation is applied.
        """
        for obs, act, rew, is_done in new_traj:
            self.D.append(obs, act.cpu(), rew, is_done)

    def _compute_loss_world(self, state, data):
        """Compute the world-model loss terms for one batch.

        :param state: 7-tuple ``(beliefs, prior_states, prior_means,
            prior_std_devs, posterior_states, posterior_means,
            posterior_std_devs)``; ``prior_states`` is unused.
        :param data: ``(observations, rewards, nonterminals)``.
        :returns: ``(observation_loss, reward_scale * reward_loss, kl_loss,
            pcont_term)``. ``pcont_term`` is ``pcont_scale`` times the
            binary cross-entropy of the continuation model when
            ``args.pcont`` is set, and ``0`` otherwise.
        """
        (beliefs, _prior_states, prior_means, prior_std_devs,
         posterior_states, posterior_means, posterior_std_devs) = state
        observations, rewards, nonterminals = data
        latent = (beliefs, posterior_states)

        # Squared error summed over the feature dims (dim 2 for symbolic
        # observations, dims 2-4 for images), then averaged over dims 0 and 1.
        obs_pred = bottle(self.observation_model, latent)
        feature_dims = 2 if self.args.symbolic else (2, 3, 4)
        observation_loss = F.mse_loss(obs_pred, observations,
                                      reduction='none').sum(
                                          dim=feature_dims).mean(dim=(0, 1))

        reward_pred = bottle(self.reward_model, latent)
        reward_loss = F.mse_loss(reward_pred, rewards,
                                 reduction='none').mean(dim=(0, 1))

        # Transition loss: KL(posterior || prior), clamped from below by free nats.
        posterior = Independent(Normal(posterior_means, posterior_std_devs), 1)
        prior = Independent(Normal(prior_means, prior_std_devs), 1)
        kl_loss = torch.max(kl_divergence(posterior, prior),
                            self.free_nats).mean(dim=(0, 1))

        pcont_term = 0
        if self.args.pcont:
            pcont_pred = bottle(self.pcont_model, latent)
            pcont_term = self.args.pcont_scale * F.binary_cross_entropy(
                pcont_pred, nonterminals)

        return observation_loss, self.args.reward_scale * reward_loss, kl_loss, pcont_term

    def _compute_loss_actor(self,
                            imag_beliefs,
                            imag_states,
                            imag_ac_logps=None):
        """Actor objective over imagined trajectories.

        Rewards come from ``self.reward_model``. Values are the element-wise
        minimum of ``self.value_model`` and ``self.value_model2``. When
        ``imag_ac_logps`` is given, ``args.temp * imag_ac_logps`` is
        subtracted in place from ``values[1:]``. Lambda-returns come from
        ``cal_returns`` and are weighted by a cumulative product of the
        (detached) continuation factors.

        Returns:
            Scalar tensor: the negative mean of the discount-weighted returns.
        """
        latent = (imag_beliefs, imag_states)
        imag_rewards = bottle(self.reward_model, latent)
        imag_values = torch.min(bottle(self.value_model, latent),
                                bottle(self.value_model2, latent))

        # Continuation factors are treated as constants for the actor.
        with torch.no_grad():
            if not self.args.pcont:
                pcont = self.args.discount * torch.ones_like(imag_rewards)
            else:
                pcont = bottle(self.pcont_model, latent)
        pcont = pcont.detach()

        if imag_ac_logps is not None:
            # entropy bonus
            imag_values[1:] -= self.args.temp * imag_ac_logps

        returns = cal_returns(imag_rewards[:-1],
                              imag_values[:-1],
                              imag_values[-1],
                              pcont[:-1],
                              lambda_=self.args.disclam)

        # Weight 1 for the first step, then the running product of pcont.
        leading_one = torch.ones_like(pcont[:1])
        discount = torch.cumprod(torch.cat([leading_one, pcont[:-2]], 0),
                                 0).detach()

        assert list(discount.size()) == list(returns.size())
        return -torch.mean(discount * returns)

    def _compute_loss_critic(self,
                             imag_beliefs,
                             imag_states,
                             imag_ac_logps=None):
        """Regression loss of both value heads against lambda-return targets.

        The targets are built without gradients. They use the element-wise
        minimum of the two target value networks, rewards from
        ``self.reward_model``, and either predicted continuation
        (``args.pcont``) or the constant ``args.discount``. With
        ``imag_ac_logps``, ``args.temp * imag_ac_logps`` is subtracted from
        ``target_values[1:]``.

        Returns:
            Sum of the MSE losses of ``value_model`` and ``value_model2``
            (each averaged over dims 0 and 1).
        """
        latent = (imag_beliefs, imag_states)
        args = self.args

        with torch.no_grad():
            target_values = torch.min(
                bottle(self.target_value_model, latent),
                bottle(self.target_value_model2, latent))
            imag_rewards = bottle(self.reward_model, latent)
            pcont = (bottle(self.pcont_model, latent) if args.pcont else
                     args.discount * torch.ones_like(imag_rewards))
            if imag_ac_logps is not None:
                target_values[1:] -= args.temp * imag_ac_logps

        target_return = cal_returns(imag_rewards[:-1],
                                    target_values[:-1],
                                    target_values[-1],
                                    pcont[:-1],
                                    lambda_=args.disclam).detach()

        # Both online heads regress onto the same target (last step dropped).
        predictions = [
            bottle(head, latent)[:-1]
            for head in (self.value_model, self.value_model2)
        ]
        per_head = [
            F.mse_loss(pred, target_return, reduction="none").mean(dim=(0, 1))
            for pred in predictions
        ]
        return per_head[0] + per_head[1]

    def _latent_imagination(self,
                            beliefs,
                            posterior_states,
                            with_logprob=False):
        # Rollout to generate imagined trajectories

        chunk_size, batch_size, _ = list(
            posterior_states.size())  # flatten the tensor
        flatten_size = chunk_size * batch_size

        posterior_states = posterior_states.detach().reshape(flatten_size, -1)
        beliefs = beliefs.detach().reshape(flatten_size, -1)

        imag_beliefs, imag_states, imag_ac_logps = [beliefs
                                                    ], [posterior_states], []

        for i in range(self.args.planning_horizon):
            imag_action, imag_ac_logp = self.actor_model(
                imag_beliefs[-1].detach(),
                imag_states[-1].detach(),
                deterministic=False,
                with_logprob=with_logprob,
            )
            imag_action = imag_action.unsqueeze(dim=0)  # add time dim

            # print(imag_states[-1].shape, imag_action.shape, imag_beliefs[-1].shape)
            imag_belief, imag_state, _, _ = self.transition_model(
                imag_states[-1], imag_action, imag_beliefs[-1])
            imag_beliefs.append(imag_belief.squeeze(dim=0))
            imag_states.append(imag_state.squeeze(dim=0))
            if with_logprob:
                imag_ac_logps.append(imag_ac_logp.squeeze(dim=0))

        imag_beliefs = torch.stack(imag_beliefs, dim=0).to(
            self.args.device
        )  # shape [horizon+1, (chuck-1)*batch, belief_size]
        imag_states = torch.stack(imag_states, dim=0).to(self.args.device)
        if with_logprob:
            imag_ac_logps = torch.stack(imag_ac_logps, dim=0).to(
                self.args.device)  # shape [horizon, (chuck-1)*batch]

        return imag_beliefs, imag_states, imag_ac_logps if with_logprob else None

    def update_parameters(self, gradient_steps):
        """Run ``gradient_steps`` rounds of world-model, actor and critic updates.

        Each round samples ``args.batch_size`` sequences of length
        ``args.chunk_size`` from the replay buffer ``self.D`` and then:

        1. world model: runs the transition model over the whole encoded
           sequence and steps ``world_optimizer`` on observation + reward +
           KL (+ pcont) loss;
        2. actor: with the world model and both value heads frozen, imagines
           rollouts from the posterior states, optionally tunes the entropy
           temperature (``args.auto_temp``), and steps ``actor_optimizer``;
        3. critic: with everything unfrozen again, steps ``value_optimizer``
           on the detached imagined trajectories.

        After the last round, both target value networks are hard-copied from
        the online value networks.

        Args:
            gradient_steps: number of update rounds.

        Returns:
            list with one ``[observation, reward, kl, pcont, actor, critic]``
            row of scalar loss values per round. The pcont entry is 0 when
            ``args.pcont`` is off.
        """
        args = self.args

        def set_frozen(frozen):
            # Toggle gradient tracking of the world model and both value
            # heads. While frozen, actor_loss.backward() still flows through
            # these modules to the actor but computes no grads for their own
            # parameters (saving memory and compute).
            for params in (self.world_param,
                           self.value_model.parameters(),
                           self.value_model2.parameters()):
                for p in params:
                    p.requires_grad = not frozen

        loss_info = []  # one row of scalar losses per gradient step
        for _ in tqdm(range(gradient_steps)):
            # Sample sequences and start from an all-zero belief/state.
            observations, actions, rewards, nonterminals = self.D.sample(
                args.batch_size, args.chunk_size)
            init_belief = torch.zeros(args.batch_size,
                                      args.belief_size,
                                      device=args.device)
            init_state = torch.zeros(args.batch_size,
                                     args.state_size,
                                     device=args.device)

            # Update belief/state using the posterior from the previous
            # belief/state, previous action and current observation, over
            # the entire sequence at once.
            (beliefs, prior_states, prior_means, prior_std_devs,
             posterior_states, posterior_means,
             posterior_std_devs) = self.transition_model(
                 init_state, actions, init_belief,
                 bottle(self.encoder, (observations, )),
                 nonterminals)  # TODO: 4

            # --- world model update ---
            (observation_loss, reward_loss, kl_loss,
             pcont_loss) = self._compute_loss_world(
                 state=(beliefs, prior_states, prior_means, prior_std_devs,
                        posterior_states, posterior_means,
                        posterior_std_devs),
                 data=(observations, rewards, nonterminals))
            self.world_optimizer.zero_grad()
            (observation_loss + reward_loss + kl_loss + pcont_loss).backward()
            nn.utils.clip_grad_norm_(self.world_param,
                                     args.grad_clip_norm,
                                     norm_type=2)
            self.world_optimizer.step()

            # --- actor (and temperature) update, critic/world frozen ---
            set_frozen(True)
            try:
                imag_beliefs, imag_states, imag_ac_logps = self._latent_imagination(
                    beliefs, posterior_states, with_logprob=args.with_logprob)

                if args.auto_temp:
                    # NOTE: imag_ac_logps is None unless args.with_logprob.
                    temp_loss = -(
                        self.log_temp *
                        (imag_ac_logps[0] + self.target_entropy).detach()).mean()
                    self.temp_optimizer.zero_grad()
                    temp_loss.backward()
                    self.temp_optimizer.step()
                    # Detached: the actor/critic losses treat the temperature
                    # as a constant; it is trained only through temp_loss.
                    args.temp = self.log_temp.exp().detach()

                actor_loss = self._compute_loss_actor(
                    imag_beliefs, imag_states, imag_ac_logps=imag_ac_logps)
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                nn.utils.clip_grad_norm_(self.actor_model.parameters(),
                                         args.grad_clip_norm,
                                         norm_type=2)
                self.actor_optimizer.step()
            finally:
                # Always unfreeze, even if the actor step raised.
                set_frozen(False)

            # --- critic update on detached imagined trajectories ---
            critic_loss = self._compute_loss_critic(
                imag_beliefs.detach(),
                imag_states.detach(),
                imag_ac_logps=imag_ac_logps)
            self.value_optimizer.zero_grad()
            critic_loss.backward()
            nn.utils.clip_grad_norm_(self.value_model.parameters(),
                                     args.grad_clip_norm,
                                     norm_type=2)
            nn.utils.clip_grad_norm_(self.value_model2.parameters(),
                                     args.grad_clip_norm,
                                     norm_type=2)
            self.value_optimizer.step()

            loss_info.append([
                observation_loss.item(),
                reward_loss.item(),
                kl_loss.item(),
                pcont_loss.item() if args.pcont else 0,
                actor_loss.item(),
                critic_loss.item(),
            ])

        # Hard-copy the online value networks into the targets once per call.
        with torch.no_grad():
            self.target_value_model.load_state_dict(
                self.value_model.state_dict())
            self.target_value_model2.load_state_dict(
                self.value_model2.state_dict())

        return loss_info

    def infer_state(self, observation, action, belief=None, state=None):
        """Infer the posterior over the current state, q(s_t | o<=t, a<t).

        A leading time axis of length 1 is added to ``action`` and to the
        encoded ``observation`` before the transition model is called. That
        axis is then squeezed off the returned belief and posterior state.

        Returns:
            (belief, posterior_state) at time t, without the time dimension.
        """
        timed_action = action.unsqueeze(dim=0)
        timed_embedding = self.encoder(observation).unsqueeze(dim=0)
        (next_belief, _prior_state, _prior_mean, _prior_std,
         posterior_state, _post_mean, _post_std) = self.transition_model(
             state, timed_action, belief, timed_embedding)

        return next_belief.squeeze(dim=0), posterior_state.squeeze(dim=0)

    def select_action(self, state, deterministic=False):
        """Choose an action from a ``(belief, posterior_state)`` pair.

        ``state`` is the tuple returned by ``infer_state``. When
        ``deterministic`` is false and ``args.with_logprob`` is false, the
        following happens:

        * Gaussian exploration noise with std ``args.expl_amount`` is added.
        * Column 0 (angle) is clamped to ``[angle_min, angle_max]``.
        * Column 1 (throttle) is set to ``args.throttle_base`` when
          ``args.fix_speed`` is set, otherwise it is clamped to
          ``[throttle_min, throttle_max]``.

        Otherwise the actor output is returned unchanged.

        Returns:
            The action as a tensor (not converted to numpy), presumably
            shaped [batch, act_size].
        """
        belief, posterior_state = state
        action, _ = self.actor_model(belief,
                                     posterior_state,
                                     deterministic=deterministic,
                                     with_logprob=False)
        # Debug prints previously here were removed: they fired on every
        # environment step and forced a device sync to print the tensor.
        if deterministic or self.args.with_logprob:
            return action

        action = Normal(action, self.args.expl_amount).rsample()

        # clip the angle
        action[:, 0].clamp_(min=self.args.angle_min, max=self.args.angle_max)
        # clip (or fix) the throttle
        if self.args.fix_speed:
            action[:, 1] = self.args.throttle_base
        else:
            action[:, 1].clamp_(min=self.args.throttle_min,
                                max=self.args.throttle_max)
        return action

    def import_parameters(self, params):
        """Load only the weights needed for a local rollout.

        ``params`` maps "encoder", "policy" and "transition" to state dicts
        (the same keys ``export_parameters`` produces). They are loaded into
        ``self.encoder``, ``self.actor_model`` and ``self.transition_model``,
        in that order.
        """
        targets = (("encoder", "encoder"),
                   ("policy", "actor_model"),
                   ("transition", "transition_model"))
        for key, attr_name in targets:
            getattr(self, attr_name).load_state_dict(params[key])

    def export_parameters(self):
        """Return CPU state dicts of the modules used for a local rollout.

        The encoder, actor and transition modules are moved to the CPU to take
        their state dicts, then moved back to ``args.device``.

        Returns:
            dict with keys "encoder", "policy" and "transition".
        """
        encoder_sd = self.encoder.cpu().state_dict()
        policy_sd = self.actor_model.cpu().state_dict()
        transition_sd = self.transition_model.cpu().state_dict()
        exported = {
            "encoder": encoder_sd,
            "policy": policy_sd,
            "transition": transition_sd,
        }

        for module in (self.encoder, self.actor_model, self.transition_model):
            module.to(self.args.device)

        return exported
def main():
    """
    Training and validation.

    Trains ``args.num_models`` captioning models one after another. Each is
    saved into its own ``<args.model>/model_<n>`` directory. Only the first
    model can resume from the global ``checkpoint`` path. After each model,
    the checkpoint, epoch counters, best BLEU-4 and ``fine_tune_encoder`` are
    reset so the next model trains from scratch.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map, lowest_loss_val

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Train N models and save them to each directory
    for n in range(1, args.num_models + 1):
        # Directory where the model will be saved; an existing one is reused.
        model_out = os.path.join(args.model, "model_{}".format(n))
        os.makedirs(model_out, exist_ok=True)

        # Initialize / load checkpoint
        if checkpoint is None:
            decoder = DecoderWithAttention(attention_dim=attention_dim,
                                           embed_dim=emb_dim,
                                           decoder_dim=decoder_dim,
                                           vocab_size=len(word_map),
                                           dropout=dropout)
            decoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, decoder.parameters()),
                                                 lr=decoder_lr)
            encoder = Encoder()
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(
                params=filter(lambda p: p.requires_grad, encoder.parameters()),
                lr=encoder_lr) if fine_tune_encoder else None
        else:
            checkpoint = torch.load(checkpoint)
            start_epoch = checkpoint['epoch'] + 1
            epochs_since_improvement = checkpoint['epochs_since_improvement']
            best_bleu4 = checkpoint['bleu-4']
            decoder = checkpoint['decoder']
            decoder_optimizer = checkpoint['decoder_optimizer']
            encoder = checkpoint['encoder']
            encoder_optimizer = checkpoint['encoder_optimizer']
            if fine_tune_encoder is True and encoder_optimizer is None:
                encoder.fine_tune(fine_tune_encoder)
                encoder_optimizer = torch.optim.Adam(params=filter(
                    lambda p: p.requires_grad, encoder.parameters()),
                                                     lr=encoder_lr)

        # Move to GPU, if available
        decoder = decoder.to(device)
        encoder = encoder.to(device)

        # Loss function
        criterion = nn.CrossEntropyLoss().to(device)

        # Custom dataloaders
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_loader = torch.utils.data.DataLoader(CaptionDataset(
            data_folder,
            data_name,
            'TRAIN',
            transform=transforms.Compose([normalize])),
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=workers,
                                                   pin_memory=True)
        val_loader = torch.utils.data.DataLoader(CaptionDataset(
            data_folder,
            data_name,
            'VAL',
            transform=transforms.Compose([normalize])),
                                                 batch_size=batch_size,
                                                 shuffle=True,
                                                 num_workers=workers,
                                                 pin_memory=True)

        # Epochs
        for epoch in range(start_epoch, epochs):
            # Every 8 consecutive epochs without BLEU-4 improvement, decay the
            # learning rate(s) by 0.8; stop this model after 50 such epochs.
            if epochs_since_improvement == 50:
                break
            if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
                adjust_learning_rate(decoder_optimizer, 0.8)
                if fine_tune_encoder:
                    adjust_learning_rate(encoder_optimizer, 0.8)

            # One epoch's training
            train(train_loader=train_loader,
                  encoder=encoder,
                  decoder=decoder,
                  criterion=criterion,
                  encoder_optimizer=encoder_optimizer,
                  decoder_optimizer=decoder_optimizer,
                  epoch=epoch)

            # One epoch's validation
            recent_loss, recent_bleu4 = validate(val_loader=val_loader,
                                                 encoder=encoder,
                                                 decoder=decoder,
                                                 criterion=criterion)
            # Check if there was an improvement using bleu
            is_best = recent_bleu4 > best_bleu4
            best_bleu4 = max(recent_bleu4, best_bleu4)
            # Check if there was an improvement using loss
            #is_best = recent_loss < lowest_loss_val
            #lowest_loss_val = min(recent_loss, lowest_loss_val)
            if not is_best:
                epochs_since_improvement += 1
                print("\nEpochs since last improvement: %d\n" %
                      (epochs_since_improvement, ))
            else:
                epochs_since_improvement = 0

            save_checkpoint_with_dir(model_out, data_name, epoch,
                                     epochs_since_improvement, encoder,
                                     decoder, encoder_optimizer,
                                     decoder_optimizer, recent_bleu4, is_best)

        # Release this model and its optimizers (which also reference the
        # parameters) so the cached GPU memory can actually be freed.
        del decoder, encoder, decoder_optimizer, encoder_optimizer
        torch.cuda.empty_cache()
        # Reset the global training state so the next model starts fresh.
        epochs_since_improvement = 0
        best_bleu4, start_epoch = 0, 0
        # Must rebind the *global* name: a misspelled local (`check_point`)
        # left the loaded checkpoint dict in place, and the next iteration
        # then called torch.load() on that dict.
        checkpoint = None
        fine_tune_encoder = False
Ejemplo n.º 8
0
def main():
    """
    Training and validation.

    Trains one encoder/decoder captioning model for ``args['epochs']``
    epochs. It resumes from the global ``checkpoint`` path when that is not
    None. Each epoch it trains, validates with BLEU-4, logs BLEU-4 to
    ``tensorboard_writer`` and saves a checkpoint.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Load Pretrained Embeddings and compare to Wordmap if True, otherwise reload the pickle file
    # NOTE(review): the "reload the pickle file" branch is commented out below,
    # and the resulting `pretrained_embeddings` is not used in the visible code
    # (the decoder.load_pretrained_embeddings call is commented out).
    if reload_pretrained_embed == True:
        # Parse a whitespace-separated text embedding file where each line is
        # "<word> <v1> <v2> ..." into {word: float32 vector}.
        embeddings_index = dict()
        fid = open(pretrained_embeddings_file, encoding="utf8")
        for line in fid:
            values = line.split()
            word = values[0]
            coefs = np.asarray(values[1:], dtype='float32')
            embeddings_index[word] = coefs
        fid.close()

        # One row per word_map index plus one extra row, which stays all-zero
        # unless some word maps to index len(word_map).
        pretrained_embeddings = torch.zeros((len(word_map) + 1, emb_dim))

        for word, idx in word_map.items():
            embed_vector = embeddings_index.get(word)
            if embed_vector is not None:
                # Word found in the embedding file: copy its vector.
                pretrained_embeddings[idx] = torch.from_numpy(embed_vector)
            else:
                # Word missing from the embedding file: random init in [-1, 1).
                pretrained_embeddings[idx] = torch.from_numpy(
                    np.random.uniform(-1, 1, emb_dim))

    #   print(pretrained_embeddings[0:2, :])

    #   fid = open("embedding_matrix.pkl","wb")
    #   dump(pretrained_embeddings, fid)
    #   fid.close()

    # else:
    #   pretrained_embeddings = open(pretrained_embedding_matrix, "wb")
    #   print('Successfully Loaded Pretrained Embeddings Pickle')

    # Initialize / load checkpoint
    if checkpoint is None:
        # Fresh decoder; only parameters with requires_grad are optimized.
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        # decoder.load_pretrained_embeddings(pretrained_embeddings)  # pretrained_embeddings should be of dimensions (len(word_map), emb_dim)
        # decoder.fine_tune_embeddings(True)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=args['decoder_lr'])

        # encoder = vgg_face_dag() #VGG Face
        encoder = Encoder()  #OG Encoder
        # encoder.cuda()
        # print(summary(encoder, (3, 224, 224)))
        # print('ENCODER SUMMARY')
        encoder.fine_tune(fine_tune_encoder)
        # The encoder gets an optimizer only when it is being fine-tuned.
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=args['encoder_lr']) if fine_tune_encoder else None

    else:
        # Resume: the global `checkpoint` path is rebound to the loaded dict.
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        # Checkpoint saved without encoder fine-tuning but it is requested now:
        # enable it and create the missing encoder optimizer.
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=args['encoder_lr'])

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])  #OG figures
    # normalize = transforms.Normalize(mean= [129.186279296875, 104.76238250732422, 93.59396362304688], #VGG Face figures
    #                                  std= [1, 1, 1])

    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transform=transforms.Compose([normalize])),
                                               batch_size=args['batch_size'],
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    # print('validation_loader')
    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'VAL',
        transform=transforms.Compose([normalize])),
                                             batch_size=args['batch_size'],
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, args['epochs']):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 15
        if epochs_since_improvement == 15:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation (this validate() returns only BLEU-4)
        # print('validation_loader_2')
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion,
                                epoch=epoch)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0
        tensorboard_writer.add_scalar('BLEU-4/epoch', recent_bleu4, epoch)

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)

    # NOTE(review): PATH is not used in the visible code.
    PATH = './cifar_net.pth'
    tensorboard_writer.close()

    # `task` is a module-level object defined outside this view.
    print('Task ID number is: {}'.format(task.id))