Esempio n. 1
0
class HiDDen(object):
    """Encoder-decoder plus discriminator trained on (images, messages) batches.

    ``enc_dec(images, messages)`` returns ``(encoded_images, decoded_messages)``.
    The encoder-decoder is optimised so that encoded images stay close to the
    input images (MSE), decoded messages match the input messages (MSE), and
    the discriminator labels encoded images as covers (BCE). The discriminator
    is optimised to label cover images ``cover_label`` and encoded images
    ``encod_label``.
    """

    def __init__(self, config: 'HiDDenConfiguration', device: torch.device):
        """
        :param config: configuration passed to EncoderDecoder and Discriminator;
            its ``adversarial_loss``, ``encoder_loss`` and ``decoder_loss``
            attributes weight the generator loss terms.
        :param device: torch.device the networks, losses and labels live on.
        """
        self.enc_dec = EncoderDecoder(config).to(device)
        self.discr = Discriminator(config).to(device)
        self.opt_enc_dec = torch.optim.Adam(self.enc_dec.parameters())
        self.opt_discr = torch.optim.Adam(self.discr.parameters())

        self.config = config
        self.device = device
        self.bce_with_logits_loss = nn.BCEWithLogitsLoss().to(device)
        self.mse_loss = nn.MSELoss().to(device)

        # Discriminator targets: covers are labelled 1, encoded images 0.
        self.cover_label = 1
        self.encod_label = 0

    def _target_labels(self, batch_size: int):
        """Return ``(cover, encoded)`` BCE target tensors of shape (batch_size, 1).

        The dtype is forced to float: ``torch.full`` infers int64 from an int
        fill value, and BCEWithLogitsLoss does not accept integer targets.
        """
        cover = torch.full((batch_size, 1), self.cover_label,
                           dtype=torch.float, device=self.device)
        encoded = torch.full((batch_size, 1), self.encod_label,
                             dtype=torch.float, device=self.device)
        return cover, encoded

    @staticmethod
    def _bitwise_error(decoded_messages, messages, batch_size: int):
        """Mean absolute difference between the rounded, [0, 1]-clipped decoded
        values and ``messages``; for 0/1 messages this is the bit error rate."""
        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        return np.sum(np.abs(decoded_rounded - messages.detach().cpu().numpy())) \
            / (batch_size * messages.shape[1])

    def train_on_batch(self, batch: list):
        '''
        Trains the network on a single batch consisting of images and messages.

        Performs one optimizer step on the discriminator, then one on the
        encoder-decoder.

        :param batch: ``[images, messages]``
        :return: ``(losses, (encoded_images, decoded_messages))`` where
            ``losses`` maps metric names to floats.
        '''
        images, messages = batch
        batch_size = images.shape[0]
        self.enc_dec.train()
        self.discr.train()

        d_target_label_cover, d_target_label_encoded = self._target_labels(
            batch_size)
        # The generator wants its encoded images classified as covers.
        g_target_label_encoded = d_target_label_cover

        with torch.enable_grad():
            # ---------- Train the discriminator ----------
            self.opt_discr.zero_grad()

            # train on cover
            d_on_cover = self.discr(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)
            d_loss_on_cover.backward()

            # train on fake; detach so this loss does not reach the encoder-decoder
            encoded_images, decoded_messages = self.enc_dec(images, messages)
            d_on_encoded = self.discr(encoded_images.detach())
            d_loss_on_encod = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)
            d_loss_on_encod.backward()
            self.opt_discr.step()

            # ---------- Train the generator (encoder-decoder) ----------
            # This backward also writes gradients into the discriminator's
            # parameters; opt_enc_dec only steps the encoder-decoder, and
            # opt_discr.zero_grad() clears them at the start of the next call.
            self.opt_enc_dec.zero_grad()

            d_on_encoded_for_enc = self.discr(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)
            g_loss_enc = self.mse_loss(encoded_images, images)
            g_loss_dec = self.mse_loss(decoded_messages, messages)

            g_loss = (self.config.adversarial_loss * g_loss_adv
                      + self.config.encoder_loss * g_loss_enc
                      + self.config.decoder_loss * g_loss_dec)
            g_loss.backward()
            self.opt_enc_dec.step()

        bitwise_err = self._bitwise_error(decoded_messages, messages,
                                          batch_size)

        losses = {
            'loss': g_loss.item(),
            'encoder_mse': g_loss_enc.item(),
            'decoder_mse': g_loss_dec.item(),
            'bitwise-error': bitwise_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encod.item()
        }

        return losses, (encoded_images, decoded_messages)

    def validate_on_batch(self, batch: list):
        '''Run validation on a batch consisting of [images, messages].

        Computes the same losses as ``train_on_batch`` without updating any
        weights.

        :return: ``(losses, (encoded_images, decoded_messages))``
        '''
        images, messages = batch
        batch_size = images.shape[0]

        self.enc_dec.eval()
        self.discr.eval()

        d_target_label_cover, d_target_label_encoded = self._target_labels(
            batch_size)
        g_target_label_encoded = d_target_label_cover

        with torch.no_grad():
            d_on_cover = self.discr(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)

            encoded_images, decoded_messages = self.enc_dec(images, messages)
            d_on_encoded = self.discr(encoded_images)
            d_loss_on_encod = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)

            # Reuse the discriminator logits for the adversarial loss rather
            # than running the same forward pass twice (assumes the
            # discriminator is deterministic in eval mode).
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded,
                                                   g_target_label_encoded)
            g_loss_enc = self.mse_loss(encoded_images, images)
            g_loss_dec = self.mse_loss(decoded_messages, messages)

            g_loss = (self.config.adversarial_loss * g_loss_adv
                      + self.config.encoder_loss * g_loss_enc
                      + self.config.decoder_loss * g_loss_dec)

        bitwise_err = self._bitwise_error(decoded_messages, messages,
                                          batch_size)

        # NOTE(review): 'bitwise-err' and 'discr_enced_bce' differ from the
        # keys used by train_on_batch; kept unchanged for existing callers.
        losses = {
            'loss': g_loss.item(),
            'encoder_mse': g_loss_enc.item(),
            'decoder_mse': g_loss_dec.item(),
            'bitwise-err': bitwise_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_enced_bce': d_loss_on_encod.item()
        }

        return losses, (encoded_images, decoded_messages)

    def to_stirng(self):
        """Return the encoder-decoder and discriminator reprs joined by a newline.

        The misspelled name is kept for existing callers; prefer ``to_string``.
        """
        return f'{str(self.enc_dec)}\n{str(self.discr)}'

    # Correctly spelled alias of to_stirng.
    to_string = to_stirng
Esempio n. 2
0
class Hidden:
    """Encoder-decoder with a noise layer, plus a discriminator.

    ``encoder_decoder(images, messages)`` returns
    ``(encoded_images, noised_images, decoded_messages)``. When
    ``configuration.use_vgg`` is set, the encoder loss is the MSE between
    ``VGGLoss`` outputs of cover and encoded images instead of a pixel MSE.
    If a ``tb_logger`` is supplied, gradients and weights of each network's
    final layer are logged to it.
    """

    def __init__(self, configuration: 'HiDDenConfiguration',
                 device: torch.device, noiser: 'Noiser', tb_logger):
        """
        :param configuration: Configuration for the net, such as the size of the input image, number of channels in the intermediate layers, etc.
        :param device: torch.device object, CPU or GPU
        :param noiser: Object representing stacked noise layers.
        :param tb_logger: Optional TensorboardX logger object, if specified -- enables Tensorboard logging
        """
        super().__init__()

        self.encoder_decoder = EncoderDecoder(configuration, noiser).to(device)
        self.discriminator = Discriminator(configuration).to(device)
        self.optimizer_enc_dec = torch.optim.Adam(
            self.encoder_decoder.parameters())
        self.optimizer_discrim = torch.optim.Adam(
            self.discriminator.parameters())

        if configuration.use_vgg:
            self.vgg_loss = VGGLoss(3, 1, False)
            self.vgg_loss.to(device)
        else:
            self.vgg_loss = None

        self.config = configuration
        self.device = device

        self.bce_with_logits_loss = nn.BCEWithLogitsLoss().to(device)
        self.mse_loss = nn.MSELoss().to(device)

        # Labels used for the discriminator / adversarial losses.
        self.cover_label = 1
        self.encoded_label = 0

        self.tb_logger = tb_logger
        if tb_logger is not None:
            # Log each final-layer weight gradient whenever it is computed.
            for tag, layer in self._final_layers().items():
                layer.weight.register_hook(
                    tb_logger.grad_hook_by_name(f'grads/{tag}'))

    def _final_layers(self):
        """Return ``{tag: module}`` for the last layer of encoder, decoder and
        discriminator (used for TensorBoard gradient/weight logging)."""
        return {
            'encoder_out': self.encoder_decoder.encoder._modules['final_layer'],
            'decoder_out': self.encoder_decoder.decoder._modules['linear'],
            'discrim_out': self.discriminator._modules['linear'],
        }

    def _target_labels(self, batch_size: int):
        """Return ``(cover, encoded)`` BCE target tensors of shape (batch_size, 1).

        The dtype is forced to float: ``torch.full`` infers int64 from an int
        fill value, and BCEWithLogitsLoss does not accept integer targets.
        """
        cover = torch.full((batch_size, 1), self.cover_label,
                           dtype=torch.float, device=self.device)
        encoded = torch.full((batch_size, 1), self.encoded_label,
                             dtype=torch.float, device=self.device)
        return cover, encoded

    def _encoder_loss(self, images, encoded_images):
        """Pixel MSE between encoded and cover images, or MSE between their
        ``vgg_loss`` outputs when a VGG loss module is configured."""
        if self.vgg_loss is None:
            return self.mse_loss(encoded_images, images)
        vgg_on_cov = self.vgg_loss(images)
        vgg_on_enc = self.vgg_loss(encoded_images)
        return self.mse_loss(vgg_on_cov, vgg_on_enc)

    @staticmethod
    def _bitwise_error(decoded_messages, messages, batch_size: int):
        """Mean absolute difference between the rounded, [0, 1]-clipped decoded
        values and ``messages``; for 0/1 messages this is the bit error rate."""
        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        return np.sum(
            np.abs(decoded_rounded - messages.detach().cpu().numpy())) / (
                batch_size * messages.shape[1])

    def train_on_batch(self, batch: list):
        """
        Trains the network on a single batch consisting of images and messages
        :param batch: batch of training data, in the form [images, messages]
        :return: dictionary of error metrics from Encoder, Decoder, and Discriminator on the current batch
        """
        images, messages = batch

        batch_size = images.shape[0]
        self.encoder_decoder.train()
        self.discriminator.train()

        d_target_label_cover, d_target_label_encoded = self._target_labels(
            batch_size)
        # target label for encoded images should be 'cover', because we want to fool the discriminator
        g_target_label_encoded = d_target_label_cover

        with torch.enable_grad():
            # ---------------- Discriminator losses -----------------------------
            # NOTE(review): this version never calls backward()/step() for the
            # discriminator, so its weights are not updated here; the two
            # discriminator losses are computed for reporting only. To train it
            # adversarially, call d_loss_on_cover.backward(),
            # d_loss_on_encoded.backward() and self.optimizer_discrim.step().
            self.optimizer_discrim.zero_grad()

            d_on_cover = self.discriminator(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)

            encoded_images, noised_images, decoded_messages = self.encoder_decoder(
                images, messages)
            d_on_encoded = self.discriminator(encoded_images.detach())
            d_loss_on_encoded = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)

            # --------------Train the generator (encoder-decoder) ---------------------
            self.optimizer_enc_dec.zero_grad()
            d_on_encoded_for_enc = self.discriminator(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)
            g_loss_enc = self._encoder_loss(images, encoded_images)
            g_loss_dec = self.mse_loss(decoded_messages, messages)
            g_loss = (self.config.adversarial_loss * g_loss_adv
                      + self.config.encoder_loss * g_loss_enc
                      + self.config.decoder_loss * g_loss_dec)

            g_loss.backward()
            self.optimizer_enc_dec.step()

        bitwise_avg_err = self._bitwise_error(decoded_messages, messages,
                                              batch_size)

        # Keys are space-padded to equal width (kept as-is for callers).
        losses = {
            'loss           ': g_loss.item(),
            'encoder_mse    ': g_loss_enc.item(),
            'dec_mse        ': g_loss_dec.item(),
            'bitwise-error  ': bitwise_avg_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encoded.item()
        }
        return losses, (encoded_images, noised_images, decoded_messages)

    def validate_on_batch(self, batch: list):
        """
        Runs validation on a single batch of data consisting of images and messages
        :param batch: batch of validation data, in form [images, messages]
        :return: dictionary of error metrics from Encoder, Decoder, and Discriminator on the current batch
        """
        # if TensorboardX logging is enabled, save some of the tensors.
        if self.tb_logger is not None:
            for tag, layer in self._final_layers().items():
                self.tb_logger.add_tensor(f'weights/{tag}', layer.weight)

        images, messages = batch

        batch_size = images.shape[0]

        self.encoder_decoder.eval()
        self.discriminator.eval()

        d_target_label_cover, d_target_label_encoded = self._target_labels(
            batch_size)
        g_target_label_encoded = d_target_label_cover

        with torch.no_grad():
            d_on_cover = self.discriminator(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)

            encoded_images, noised_images, decoded_messages = self.encoder_decoder(
                images, messages)

            d_on_encoded = self.discriminator(encoded_images)
            d_loss_on_encoded = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)

            # Reuse the discriminator logits for the adversarial loss rather
            # than running the same forward pass twice (assumes the
            # discriminator is deterministic in eval mode).
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded,
                                                   g_target_label_encoded)
            g_loss_enc = self._encoder_loss(images, encoded_images)
            g_loss_dec = self.mse_loss(decoded_messages, messages)
            g_loss = (self.config.adversarial_loss * g_loss_adv
                      + self.config.encoder_loss * g_loss_enc
                      + self.config.decoder_loss * g_loss_dec)

        bitwise_avg_err = self._bitwise_error(decoded_messages, messages,
                                              batch_size)

        losses = {
            'loss           ': g_loss.item(),
            'encoder_mse    ': g_loss_enc.item(),
            'dec_mse        ': g_loss_dec.item(),
            'bitwise-error  ': bitwise_avg_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encoded.item()
        }
        return losses, (encoded_images, noised_images, decoded_messages)

    def to_stirng(self):
        """Return the encoder-decoder and discriminator reprs joined by a newline.

        The misspelled name is kept for existing callers; prefer ``to_string``.
        """
        return '{}\n{}'.format(str(self.encoder_decoder),
                               str(self.discriminator))

    # Correctly spelled alias of to_stirng.
    to_string = to_stirng
Esempio n. 3
0
def main(args):
    if args.model_name is not None:
        print('Preparing to train model: {}'.format(args.model_name))

    global device
    device = torch.device(
        'cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')

    sc_will_happen = args.self_critical_from_epoch != -1

    if args.validate is None and args.lr_scheduler == 'ReduceLROnPlateau':
        print(
            'ERROR: you need to enable validation in order to use default lr_scheduler (ReduceLROnPlateau)'
        )
        print('Hint: use something like --validate=coco:val2017')
        sys.exit(1)

    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([
        # transforms.Resize((256, 256)),
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    scorers = {}
    if args.validation_scoring is not None or sc_will_happen:
        assert not (
            args.validation_scoring is None and sc_will_happen
        ), "Please provide a metric when using self-critical training"
        for s in args.validation_scoring.split(','):
            s = s.lower().strip()
            if s == 'cider':
                from eval.cider import Cider
                scorers['CIDEr'] = Cider()
            if s == 'ciderd':
                from eval.ciderD.ciderD import CiderD
                scorers['CIDEr-D'] = CiderD(df=args.cached_words)

    ########################
    # Set Model parameters #
    ########################

    # Store parameters gotten from arguments separately:
    arg_params = ModelParams.fromargs(args)

    print("Model parameters inferred from command arguments: ")
    print(arg_params)
    start_epoch = 0

    ###############################
    # Load existing model state   #
    # and update Model parameters #
    ###############################

    state = None

    if args.load_model:
        try:
            state = torch.load(args.load_model, map_location=device)
        except AttributeError:
            print(
                'WARNING: Old model found. Please use model_update.py in the model before executing this script.'
            )
            exit(1)
        new_external_features = arg_params.features.external

        params = ModelParams(state, arg_params=arg_params)
        if len(new_external_features
               ) and params.features.external != new_external_features:
            print('WARNING: external features changed: ',
                  params.features.external, new_external_features)
            print('Updating feature paths...')
            params.update_ext_features(new_external_features)
        start_epoch = state['epoch']
        print('Loaded model {} at epoch {}'.format(args.load_model,
                                                   start_epoch))
    else:
        params = arg_params
        params.command_history = []

    if params.rnn_hidden_init == 'from_features' and params.skip_start_token:
        print(
            "ERROR: Please remove --skip_start_token if you want to use image features "
            " to initialize hidden and cell states. <start> token is needed to trigger "
            " the process of sequence generation, since we don't have image features "
            " embedding as the first input token.")
        sys.exit(1)

    # Force set the following hierarchical model parameters every time:
    if arg_params.hierarchical_model:
        params.hierarchical_model = True
        params.max_sentences = arg_params.max_sentences
        params.weight_sentence_loss = arg_params.weight_sentence_loss
        params.weight_word_loss = arg_params.weight_word_loss
        params.dropout_stopping = arg_params.dropout_stopping
        params.dropout_fc = arg_params.dropout_fc
        params.coherent_sentences = arg_params.coherent_sentences
        params.coupling_alpha = arg_params.coupling_alpha
        params.coupling_beta = arg_params.coupling_beta

    assert args.replace or \
        not os.path.isdir(os.path.join(args.output_root, args.model_path, get_model_name(args, params))) or \
        not (args.load_model and not args.validate_only), \
        '{} already exists. If you want to replace it or resume training please use --replace flag. ' \
        'If you want to validate a loaded model without training it, use --validate_only flag.'  \
        'Otherwise specify a different model name using --model_name flag.'\
        .format(os.path.join(args.output_root, args.model_path, get_model_name(args, params)))

    if args.load_model:
        print("Final model parameters (loaded model + command arguments): ")
        print(params)

    ##############################
    # Load dataset configuration #
    ##############################

    dataset_configs = DatasetParams(args.dataset_config_file)

    if args.dataset is None and not args.validate_only:
        print('ERROR: No dataset selected!')
        print(
            'Please supply a training dataset with the argument --dataset DATASET'
        )
        print('The following datasets are configured in {}:'.format(
            args.dataset_config_file))
        for ds, _ in dataset_configs.config.items():
            if ds not in ('DEFAULT', 'generic'):
                print(' ', ds)
        sys.exit(1)

    if args.validate_only:
        if args.load_model is None:
            print(
                'ERROR: for --validate_only you need to specify a model to evaluate using --load_model MODEL'
            )
            sys.exit(1)
    else:
        dataset_params = dataset_configs.get_params(args.dataset)

        for i in dataset_params:
            i.config_dict['no_tokenize'] = args.no_tokenize
            i.config_dict['show_tokens'] = args.show_tokens
            i.config_dict['skip_start_token'] = params.skip_start_token

            if params.hierarchical_model:
                i.config_dict['hierarchical_model'] = True
                i.config_dict['max_sentences'] = params.max_sentences
                i.config_dict['crop_regions'] = False

    if args.validate is not None:
        validation_dataset_params = dataset_configs.get_params(args.validate)
        for i in validation_dataset_params:
            i.config_dict['no_tokenize'] = args.no_tokenize
            i.config_dict['show_tokens'] = args.show_tokens
            i.config_dict['skip_start_token'] = params.skip_start_token

            if params.hierarchical_model:
                i.config_dict['hierarchical_model'] = True
                i.config_dict['max_sentences'] = params.max_sentences
                i.config_dict['crop_regions'] = False

    #######################
    # Load the vocabulary #
    #######################

    # For pre-trained models attempt to obtain
    # saved vocabulary from the model itself:
    if args.load_model and params.vocab is not None:
        print("Loading vocabulary from the model file:")
        vocab = params.vocab
    else:
        if args.vocab is None:
            print(
                "ERROR: You must specify the vocabulary to be used for training using "
                "--vocab flag.\nTry --vocab AUTO if you want the vocabulary to be "
                "either generated from the training dataset or loaded from cache."
            )
            sys.exit(1)
        print("Loading / generating vocabulary:")
        vocab = get_vocab(args, dataset_params)

    print('Size of the vocabulary is {}'.format(len(vocab)))

    ##########################
    # Initialize data loader #
    ##########################

    ext_feature_sets = [
        params.features.external, params.persist_features.external
    ]
    if not args.validate_only:
        print('Loading dataset: {} with {} workers'.format(
            args.dataset, args.num_workers))
        if params.skip_start_token:
            print("Skipping the use of <start> token...")
        data_loader, ef_dims = get_loader(
            dataset_params,
            vocab,
            transform,
            args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            ext_feature_sets=ext_feature_sets,
            skip_images=not params.has_internal_features(),
            verbose=args.verbose,
            unique_ids=sc_will_happen)
        if sc_will_happen:
            gts_sc = get_ground_truth_captions(data_loader.dataset)

    gts_sc_valid = None
    if args.validate is not None:
        valid_loader, ef_dims = get_loader(
            validation_dataset_params,
            vocab,
            transform,
            args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            ext_feature_sets=ext_feature_sets,
            skip_images=not params.has_internal_features(),
            verbose=args.verbose)
        gts_sc_valid = get_ground_truth_captions(
            valid_loader.dataset) if sc_will_happen else None

    #########################################
    # Setup (optional) TensorBoardX logging #
    #########################################

    writer = None
    if args.tensorboard:
        if SummaryWriter is not None:
            model_name = get_model_name(args, params)
            timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
            log_dir = os.path.join(
                args.output_root, 'log_tb/{}_{}'.format(model_name, timestamp))
            writer = SummaryWriter(log_dir=log_dir)
            print("INFO: Logging TensorBoardX events to {}".format(log_dir))
        else:
            print(
                "WARNING: SummaryWriter object not available. "
                "Hint: Please install TensorBoardX using pip install tensorboardx"
            )

    ######################
    # Build the model(s) #
    ######################

    # Set per parameter learning rate here, if supplied by the user:

    if args.lr_word_decoder is not None:
        if not params.hierarchical_model:
            print(
                "ERROR: Setting word decoder learning rate currently supported in Hierarchical Model only."
            )
            sys.exit(1)

        lr_dict = {'word_decoder': args.lr_word_decoder}
    else:
        lr_dict = {}

    model = EncoderDecoder(params,
                           device,
                           len(vocab),
                           state,
                           ef_dims,
                           lr_dict=lr_dict)

    ######################
    # Optimizer and loss #
    ######################

    sc_activated = False
    opt_params = model.get_opt_params()

    # Loss and optimizer
    if params.hierarchical_model:
        criterion = HierarchicalXEntropyLoss(
            weight_sentence_loss=params.weight_sentence_loss,
            weight_word_loss=params.weight_word_loss)
    elif args.share_embedding_weights:
        criterion = SharedEmbeddingXentropyLoss(param_lambda=0.15)
    else:
        criterion = nn.CrossEntropyLoss()

    if sc_will_happen:  # save it for later
        if args.self_critical_loss == 'sc':
            from model.loss import SelfCriticalLoss
            rl_criterion = SelfCriticalLoss()
        elif args.self_critical_loss == 'sc_with_diversity':
            from model.loss import SelfCriticalWithDiversityLoss
            rl_criterion = SelfCriticalWithDiversityLoss()
        elif args.self_critical_loss == 'sc_with_relative_diversity':
            from model.loss import SelfCriticalWithRelativeDiversityLoss
            rl_criterion = SelfCriticalWithRelativeDiversityLoss()
        elif args.self_critical_loss == 'sc_with_bleu_diversity':
            from model.loss import SelfCriticalWithBLEUDiversityLoss
            rl_criterion = SelfCriticalWithBLEUDiversityLoss()
        elif args.self_critical_loss == 'sc_with_repetition':
            from model.loss import SelfCriticalWithRepetitionLoss
            rl_criterion = SelfCriticalWithRepetitionLoss()
        elif args.self_critical_loss == 'mixed':
            from model.loss import MixedLoss
            rl_criterion = MixedLoss()
        elif args.self_critical_loss == 'mixed_with_face':
            from model.loss import MixedWithFACELoss
            rl_criterion = MixedWithFACELoss(vocab_size=len(vocab))
        elif args.self_critical_loss in [
                'sc_with_penalty', 'sc_with_penalty_throughout',
                'sc_masked_tokens'
        ]:
            raise ValueError('Deprecated loss, use \'sc\' loss')
        else:
            raise ValueError('Invalid self-critical loss')

        print('Selected self-critical loss is', rl_criterion)

        if start_epoch >= args.self_critical_from_epoch:
            criterion = rl_criterion
            sc_activated = True
            print('Self-critical loss training begins')

    # When using CyclicalLR, default learning rate should be always 1.0
    if args.lr_scheduler == 'CyclicalLR':
        default_lr = 1.
    else:
        default_lr = 0.001

    if sc_activated:
        optimizer = torch.optim.Adam(
            opt_params,
            lr=args.learning_rate if args.learning_rate else 5e-5,
            weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(opt_params,
                                     lr=default_lr,
                                     weight_decay=args.weight_decay)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(opt_params,
                                        lr=default_lr,
                                        weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(opt_params,
                                    lr=default_lr,
                                    weight_decay=args.weight_decay)
    else:
        print('ERROR: unknown optimizer:', args.optimizer)
        sys.exit(1)

    # We don't want to initialize the optimizer if we are transfering
    # the language model from the regular model to hierarchical model
    transfer_language_model = False

    if arg_params.hierarchical_model and state and not state.get(
            'hierarchical_model'):
        transfer_language_model = True

    # Set optimizer state to the one found in a loaded model, unless
    # we are doing a transfer learning step from flat to hierarchical model,
    # or we are using self-critical loss,
    # or the number of unique parameter groups has changed, or the user
    # has explicitly told us *not to* reuse optimizer parameters from before
    if state and not transfer_language_model and not sc_activated and not args.optimizer_reset:
        # Check that number of parameter groups is the same
        if len(optimizer.param_groups) == len(
                state['optimizer']['param_groups']):
            optimizer.load_state_dict(state['optimizer'])

    # override lr if set explicitly in arguments -
    # 1) Global learning rate:
    if args.learning_rate:
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.learning_rate
        params.learning_rate = args.learning_rate
    else:
        params.learning_rate = default_lr

    # 2) Parameter-group specific learning rate:
    if args.lr_word_decoder is not None:
        # We want to give user an option to set learning rate for word_decoder
        # separately. Other exceptions can be added as needed:
        for param_group in optimizer.param_groups:
            if param_group.get('name') == 'word_decoder':
                param_group['lr'] = args.lr_word_decoder
                break

    if args.validate is not None and args.lr_scheduler == 'ReduceLROnPlateau':
        print('Using ReduceLROnPlateau learning rate scheduler')
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               'min',
                                                               verbose=True,
                                                               patience=2)
    elif args.lr_scheduler == 'StepLR':
        print('Using StepLR learning rate scheduler with step_size {}'.format(
            args.lr_step_size))
        # Decrease the learning rate by the factor of gamma at every
        # step_size epochs (for example every 5 or 10 epochs):
        step_size = args.lr_step_size
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size,
                                                    gamma=0.5,
                                                    last_epoch=-1)
    elif args.lr_scheduler == 'CyclicalLR':
        print(
            "Using Cyclical learning rate scheduler, lr range: [{},{}]".format(
                args.lr_cyclical_min, args.lr_cyclical_max))

        step_size = len(data_loader)
        clr = cyclical_lr(step_size,
                          min_lr=args.lr_cyclical_min,
                          max_lr=args.lr_cyclical_max)
        n_groups = len(optimizer.param_groups)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      [clr] * n_groups)
    elif args.lr_scheduler is not None:
        print('ERROR: Invalid learing rate scheduler specified: {}'.format(
            args.lr_scheduler))
        sys.exit(1)

    ###################
    # Train the model #
    ###################

    stats_postfix = None
    if args.validate_only:
        stats_postfix = args.validate
    if args.load_model:
        all_stats = init_stats(args, params, postfix=stats_postfix)
    else:
        all_stats = {}

    if args.force_epoch:
        start_epoch = args.force_epoch - 1

    if not args.validate_only:
        total_step = len(data_loader)
        print(
            'Start training with start_epoch={:d} num_epochs={:d} num_batches={:d} ...'
            .format(start_epoch, args.num_epochs, args.num_batches))

    if args.teacher_forcing != 'always':
        print('\t k: {}'.format(args.teacher_forcing_k))
        print('\t beta: {}'.format(args.teacher_forcing_beta))
    print('Optimizer:', optimizer)

    if args.validate_only:
        stats = {}
        teacher_p = 1.0
        if args.teacher_forcing != 'always':
            print(
                'WARNING: teacher_forcing!=always, not yet implemented for --validate_only mode'
            )

        epoch = start_epoch - 1
        if str(epoch +
               1) in all_stats.keys() and args.skip_existing_validations:
            print('WARNING: epoch {} already validated, skipping...'.format(
                epoch + 1))
            return

        val_loss = do_validate(model, valid_loader, criterion, scorers, vocab,
                               teacher_p, args, params, stats, epoch,
                               sc_activated, gts_sc_valid)
        all_stats[str(epoch + 1)] = stats
        save_stats(args, params, all_stats, postfix=stats_postfix)
    else:
        for epoch in range(start_epoch, args.num_epochs):
            stats = {}
            begin = datetime.now()

            total_loss = 0

            if params.hierarchical_model:
                total_loss_sent = 0
                total_loss_word = 0

            num_batches = 0
            vocab_counts = {
                'cnt': 0,
                'max': 0,
                'min': 9999,
                'sum': 0,
                'unk_cnt': 0,
                'unk_sum': 0
            }

            # If start self critical training
            if not sc_activated and sc_will_happen and epoch >= args.self_critical_from_epoch:
                if all_stats:
                    best_ep, best_cider = max(
                        [(ep, all_stats[ep]['validation_cider'])
                         for ep in all_stats],
                        key=lambda x: x[1])
                    print('Loading model from epoch', best_ep,
                          'which has the better score with', best_cider)
                    state = torch.load(
                        get_model_path(args, params, int(best_ep)))
                    model = EncoderDecoder(params,
                                           device,
                                           len(vocab),
                                           state,
                                           ef_dims,
                                           lr_dict=lr_dict)
                    opt_params = model.get_opt_params()

                optimizer = torch.optim.Adam(opt_params,
                                             lr=5e-5,
                                             weight_decay=args.weight_decay)
                criterion = rl_criterion
                print('Self-critical loss training begins')
                sc_activated = True

            for i, data in enumerate(data_loader):

                if params.hierarchical_model:
                    (images, captions, lengths, image_ids, features,
                     sorting_order, last_sentence_indicator) = data
                    sorting_order = sorting_order.to(device)
                else:
                    (images, captions, lengths, image_ids, features) = data

                if epoch == 0:
                    unk = vocab('<unk>')
                    for j in range(captions.shape[0]):
                        # Flatten the caption in case it's a paragraph
                        # this is harmless for regular captions too:
                        xl = captions[j, :].view(-1)
                        xw = xl > unk
                        xu = xl == unk
                        xwi = sum(xw).item()
                        xui = sum(xu).item()
                        vocab_counts['cnt'] += 1
                        vocab_counts['sum'] += xwi
                        vocab_counts['max'] = max(vocab_counts['max'], xwi)
                        vocab_counts['min'] = min(vocab_counts['min'], xwi)
                        vocab_counts['unk_cnt'] += xui > 0
                        vocab_counts['unk_sum'] += xui
                # Set mini-batch dataset
                images = images.to(device)
                captions = captions.to(device)

                # Remove <start> token from targets if we are initializing the RNN
                # hidden state from image features:
                if params.rnn_hidden_init == 'from_features' and not params.hierarchical_model:
                    # Subtract one from all lengths to match new target lengths:
                    lengths = [x - 1 if x > 0 else x for x in lengths]
                    targets = pack_padded_sequence(captions[:, 1:],
                                                   lengths,
                                                   batch_first=True)[0]
                else:
                    if params.hierarchical_model:
                        targets = prepare_hierarchical_targets(
                            last_sentence_indicator, args.max_sentences,
                            lengths, captions, device)
                    else:
                        targets = pack_padded_sequence(captions,
                                                       lengths,
                                                       batch_first=True)[0]
                        sorting_order = None

                init_features = features[0].to(device) if len(
                    features) > 0 and features[0] is not None else None
                persist_features = features[1].to(device) if len(
                    features) > 1 and features[1] is not None else None

                # Forward, backward and optimize
                # Calculate the probability whether to use teacher forcing or not:

                # Iterate over batches:
                iteration = (epoch - start_epoch) * len(data_loader) + i

                teacher_p = get_teacher_prob(args.teacher_forcing_k, iteration,
                                             args.teacher_forcing_beta)

                # Allow model to log values at the last batch of the epoch
                writer_data = None
                if writer and (i == len(data_loader) - 1
                               or i == args.num_batches - 1):
                    writer_data = {'writer': writer, 'epoch': epoch + 1}

                sample_len = captions.size(1) if args.self_critical_loss in [
                    'mixed', 'mixed_with_face'
                ] else 20
                if sc_activated:
                    sampled_seq, sampled_log_probs, outputs = model.sample(
                        images,
                        init_features,
                        persist_features,
                        max_seq_length=sample_len,
                        start_token_id=vocab('<start>'),
                        trigram_penalty_alpha=args.trigram_penalty_alpha,
                        stochastic_sampling=True,
                        output_logprobs=True,
                        output_outputs=True)
                    sampled_seq = model.decoder.alt_prob_to_tensor(
                        sampled_seq, device=device)
                else:
                    outputs = model(images,
                                    init_features,
                                    captions,
                                    lengths,
                                    persist_features,
                                    teacher_p,
                                    args.teacher_forcing,
                                    sorting_order,
                                    writer_data=writer_data)

                if args.share_embedding_weights:
                    # Weights of (HxH) projection matrix used for regularizing
                    # models that share embedding weights
                    projection = model.decoder.projection.weight
                    loss = criterion(projection, outputs, targets)
                elif sc_activated:
                    # get greedy decoding baseline
                    model.eval()
                    with torch.no_grad():
                        greedy_sampled_seq = model.sample(
                            images,
                            init_features,
                            persist_features,
                            max_seq_length=sample_len,
                            start_token_id=vocab('<start>'),
                            trigram_penalty_alpha=args.trigram_penalty_alpha,
                            stochastic_sampling=False)
                        greedy_sampled_seq = model.decoder.alt_prob_to_tensor(
                            greedy_sampled_seq, device=device)
                    model.train()

                    if args.self_critical_loss in [
                            'sc', 'sc_with_diversity',
                            'sc_with_relative_diversity',
                            'sc_with_bleu_diversity', 'sc_with_repetition'
                    ]:
                        loss, advantage = criterion(
                            sampled_seq,
                            sampled_log_probs,
                            greedy_sampled_seq, [gts_sc[i] for i in image_ids],
                            scorers,
                            vocab,
                            return_advantage=True)
                    elif args.self_critical_loss in ['mixed']:
                        loss, advantage = criterion(
                            sampled_seq,
                            sampled_log_probs,
                            outputs,
                            greedy_sampled_seq, [gts_sc[i] for i in image_ids],
                            scorers,
                            vocab,
                            targets,
                            lengths,
                            gamma_ml_rl=args.gamma_ml_rl,
                            return_advantage=True)
                    elif args.self_critical_loss in ['mixed_with_face']:
                        loss, advantage = criterion(
                            sampled_seq,
                            sampled_log_probs,
                            outputs,
                            greedy_sampled_seq, [gts_sc[i] for i in image_ids],
                            scorers,
                            vocab,
                            captions,
                            targets,
                            lengths,
                            gamma_ml_rl=args.gamma_ml_rl,
                            return_advantage=True)
                    else:
                        raise ValueError('Invalid self-critical loss')

                    if writer is not None and i % 100 == 0:
                        writer.add_scalar('training_loss', loss.item(),
                                          epoch * len(data_loader) + i)
                        writer.add_scalar('advantage', advantage,
                                          epoch * len(data_loader) + i)
                        writer.add_scalar('lr',
                                          optimizer.param_groups[0]['lr'],
                                          epoch * len(data_loader) + i)
                else:
                    loss = criterion(outputs, targets)

                model.zero_grad()
                loss.backward()

                # Clip gradients if desired:
                if args.grad_clip is not None:
                    # grad_norms = [x.grad.data.norm(2) for x in opt_params]
                    # batch_max_grad = np.max(grad_norms)
                    # if batch_max_grad > 10.0:
                    #     print('WARNING: gradient norms larger than 10.0')

                    # torch.nn.utils.clip_grad_norm_(decoder.parameters(), 0.1)
                    # torch.nn.utils.clip_grad_norm_(encoder.parameters(), 0.1)
                    clip_gradients(optimizer, args.grad_clip)

                # Update weights:
                optimizer.step()

                # CyclicalLR requires us to update LR at every minibatch:
                if args.lr_scheduler == 'CyclicalLR':
                    scheduler.step()

                total_loss += loss.item()

                num_batches += 1

                if params.hierarchical_model:
                    _, loss_sent, _, loss_word = criterion.item_terms()
                    total_loss_sent += float(loss_sent)
                    total_loss_word += float(loss_word)

                # Print log info
                if (i + 1) % args.log_step == 0:
                    print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, '
                          'Perplexity: {:5.4f}'.format(epoch + 1,
                                                       args.num_epochs, i + 1,
                                                       total_step, loss.item(),
                                                       np.exp(loss.item())))
                    sys.stdout.flush()

                    if params.hierarchical_model:
                        weight_sent, loss_sent, weight_word, loss_word = criterion.item_terms(
                        )
                        print('Sentence Loss: {:.4f}, '
                              'Word Loss: {:.4f}'.format(
                                  float(loss_sent), float(loss_word)))
                        sys.stdout.flush()

                if i + 1 == args.num_batches:
                    break

            end = datetime.now()

            stats['training_loss'] = total_loss / num_batches

            if params.hierarchical_model:
                stats['loss_sentence'] = total_loss_sent / num_batches
                stats['loss_word'] = total_loss_word / num_batches

            print('Epoch {} duration: {}, average loss: {:.4f}'.format(
                epoch + 1, end - begin, stats['training_loss']))

            save_model(args, params, model.encoder, model.decoder, optimizer,
                       epoch, vocab)

            if epoch == 0:
                vocab_counts['avg'] = vocab_counts['sum'] / vocab_counts['cnt']
                vocab_counts['unk_cnt_per'] = 100 * vocab_counts[
                    'unk_cnt'] / vocab_counts['cnt']
                vocab_counts['unk_sum_per'] = 100 * vocab_counts[
                    'unk_sum'] / vocab_counts['sum']
                # print(vocab_counts)
                print((
                    'Training data contains {sum} words in {cnt} captions (avg. {avg:.1f} w/c)'
                    + ' with {unk_sum} <unk>s ({unk_sum_per:.1f}%)' +
                    ' in {unk_cnt} ({unk_cnt_per:.1f}%) captions').format(
                        **vocab_counts))

            ############################################
            # Validation loss and learning rate update #
            ############################################

            if args.validate is not None and (epoch +
                                              1) % args.validation_step == 0:
                val_loss = do_validate(model, valid_loader, criterion, scorers,
                                       vocab, teacher_p, args, params, stats,
                                       epoch, sc_activated, gts_sc_valid)

                if args.lr_scheduler == 'ReduceLROnPlateau':
                    scheduler.step(val_loss)
            elif args.lr_scheduler == 'StepLR':
                scheduler.step()

            all_stats[str(epoch + 1)] = stats
            save_stats(args, params, all_stats, writer=writer)

            if writer is not None:
                # Log model data to tensorboard
                log_model_data(params, model, epoch + 1, writer)

    if writer is not None:
        writer.close()