from logging import getLogger

import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW

# `ControllableModel`, `TransformerModel`, `TransformerEncoder`,
# `TransformerDecoder`, `TransformerReconstructor`, `TransformerIterator`, and
# `RegularizatableData` are assumed to be importable from the surrounding
# accelbrainbase library; their exact module paths are omitted here.


class TransformingAutoEncoderController(nn.Module, ControllableModel):
    '''
    Transforming Auto-Encoder.

    References:
        - Bahdanau, D., Cho, K., & Bengio, Y. (2014). Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473.
        - Floridi, L., & Chiriatti, M. (2020). GPT-3: Its nature, scope, limits, and consequences. Minds and Machines, 30(4), 681-694.
        - Miller, A., Fisch, A., Dodge, J., Karimi, A. H., Bordes, A., & Weston, J. (2016). Key-value memory networks for directly reading documents. arXiv preprint arXiv:1606.03126.
        - Radford, A., Narasimhan, K., Salimans, T., & Sutskever, I. (2018) Improving Language Understanding by Generative Pre-Training. OpenAI (URL: https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
        - Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., & Sutskever, I. (2019). Language models are unsupervised multitask learners. OpenAI blog, 1(8), 9.
        - Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., & Polosukhin, I. (2017). Attention is all you need. arXiv preprint arXiv:1706.03762.

    '''
    __loaded_filename = None
    __loaded_ctx = None

    # `bool` that means whether initialization in this class will be deferred or not.
    __init_deferred_flag = False

    def __init__(
        self,
        computable_loss=None,
        optimizer_f=None,
        encoder=None,
        decoder=None,
        reconstructor=None,
        layer_n=3,
        head_n=3,
        seq_len=5,
        depth_dim=100,
        hidden_dim=100,
        self_attention_activation_list=[],
        multi_head_attention_activation_list=[],
        fc_activation_list=[],
        learning_rate=1e-05,
        weight_decay=0.01,
        dropout_rate=0.5,
        ctx="cpu",
        regularizatable_data_list=[],
        not_init_flag=False,
    ):
        '''
        Init.

        Args:
            computable_loss:                is-a loss function such as `torch.nn.modules.loss._Loss`. If `None`, `nn.CrossEntropyLoss` is used.
            optimizer_f:                    is-a function that receives `self.parameters()` and returns a `torch.optim.Optimizer`. If `None`, `AdamW` is used.
            encoder:                        is-a `TransformerModel`.
            decoder:                        is-a `TransformerModel`.
            reconstructor:                  is-a `TransformerModel`.
            layer_n:                        `int` of the number of layers.
            head_n:                         `int` of the number of heads for multi-head attention model.
            seq_len:                        `int` of the length of sequences.
            depth_dim:                      `int` of dimension of dense layer.
            hidden_dim:                     `int` of dimension of hidden (encoder) layer.
            self_attention_activation_list: `list` of `str` of activation function for self-attention model.
            multi_head_attention_activation_list:   `list` of `str` of activation function for multi-head attention model.
            fc_activation_list:             `list` of `str` of activation function in fully-connected layers.
            learning_rate:                  `float` of learning rate.
            weight_decay:                   `float` of weight decay for the default `AdamW` optimizer.
            dropout_rate:                   `float` of dropout rate.
            ctx:                            Device string such as `"cpu"` or `"cuda:0"`.
            regularizatable_data_list:      `list` of `RegularizatableData`.
            not_init_flag:                  `bool` that means whether to skip optimizer initialization or not.

        '''
        super(TransformingAutoEncoderController, self).__init__()

        if computable_loss is None:
            computable_loss = nn.CrossEntropyLoss()
        self.__computable_loss = computable_loss

        if encoder is None:
            if hidden_dim is None or hidden_dim == depth_dim:
                encoder = TransformerEncoder(
                    depth_dim=depth_dim,
                    layer_n=layer_n,
                    head_n=head_n,
                    self_attention_activation_list=
                    self_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    dropout_rate=dropout_rate,
                    ctx=ctx)
            else:
                encoder = TransformerEncoder(
                    depth_dim=hidden_dim,
                    layer_n=layer_n,
                    head_n=head_n,
                    self_attention_activation_list=
                    self_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    dropout_rate=dropout_rate,
                    ctx=ctx)
            encoder.embedding_flag = False
        else:
            if isinstance(encoder, TransformerModel) is False:
                raise TypeError(
                    "The type of `encoder` must be `TransformerModel`.")

        if decoder is None:
            if hidden_dim is None or hidden_dim == depth_dim:
                decoder = TransformerDecoder(
                    head_n=head_n,
                    depth_dim=depth_dim,
                    layer_n=layer_n,
                    self_attention_activation_list=
                    self_attention_activation_list,
                    multi_head_attention_activation_list=
                    multi_head_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    ctx=ctx)
            else:
                decoder = TransformerDecoder(
                    head_n=head_n,
                    depth_dim=hidden_dim,
                    output_dim=hidden_dim,
                    layer_n=layer_n,
                    self_attention_activation_list=
                    self_attention_activation_list,
                    multi_head_attention_activation_list=
                    multi_head_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    ctx=ctx)
            decoder.embedding_flag = False
        else:
            if isinstance(decoder, TransformerModel) is False:
                raise TypeError(
                    "The type of `decoder` must be `TransformerModel`.")

        if reconstructor is None:
            if hidden_dim is None or hidden_dim == depth_dim:
                reconstructor = TransformerReconstructor(
                    head_n=head_n,
                    depth_dim=depth_dim,
                    layer_n=layer_n,
                    self_attention_activation_list=
                    self_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    ctx=ctx,
                )
            else:
                reconstructor = TransformerReconstructor(
                    head_n=head_n,
                    depth_dim=hidden_dim,
                    output_dim=depth_dim,
                    layer_n=layer_n,
                    self_attention_activation_list=
                    self_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    ctx=ctx,
                )
            reconstructor.embedding_flag = False
        else:
            if isinstance(reconstructor, TransformerModel) is False:
                raise TypeError(
                    "The type of `reconstructor` must be `TransformerModel`.")

        logger = getLogger("accelbrainbase")
        self.logger = logger

        self.encoder = encoder
        self.decoder = decoder
        self.reconstructor = reconstructor

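        # If `hidden_dim` differs from `depth_dim`, linear layers project the
        # encoder and decoder inputs into the hidden dimension, and the
        # reconstructor maps the hidden dimension back to `depth_dim`.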
        if hidden_dim is not None and hidden_dim != depth_dim:
            self.encoder_hidden_fc = nn.Linear(
                depth_dim,
                hidden_dim,
                bias=True,
            )
            self.decoder_hidden_fc = nn.Linear(
                depth_dim,
                hidden_dim,
                bias=True,
            )
            init_flag = True
        else:
            self.encoder_hidden_fc = None
            self.decoder_hidden_fc = None
            init_flag = False

        self.__ctx = ctx
        self.to(self.__ctx)

        if self.init_deferred_flag is False:
            if not_init_flag is False:
                if optimizer_f is not None:
                    if init_flag is True:
                        self.optimizer = optimizer_f(self.parameters())
                else:
                    if init_flag is True:
                        self.optimizer = AdamW(self.parameters(),
                                               lr=learning_rate,
                                               weight_decay=weight_decay)

        for v in regularizatable_data_list:
            if isinstance(v, RegularizatableData) is False:
                raise TypeError(
                    "The type of values of `regularizatable_data_list` must be `RegularizatableData`."
                )
        self.__regularizatable_data_list = regularizatable_data_list

        self.__learning_rate = learning_rate
        self.__weight_decay = weight_decay
        self.seq_len = seq_len
        self.epoch = 0

    def learn(self, iteratable_data):
        '''
        Learn samples drawn by `IteratableData.generate_learned_samples()`.

        Args:
            iteratable_data:     is-a `TransformerIterator`.
        '''
        if isinstance(iteratable_data, TransformerIterator) is False:
            raise TypeError(
                "The type of `iteratable_data` must be `TransformerIterator`.")

        self.__loss_list = []
        learning_rate = self.__learning_rate

        try:
            epoch = self.epoch
            iter_n = 0
            for encoded_observed_arr, decoded_observed_arr, encoded_mask_arr, decoded_mask_arr, test_encoded_observed_arr, test_decoded_observed_arr, test_encoded_mask_arr, test_decoded_mask_arr, training_target_arr, test_target_arr in iteratable_data.generate_learned_samples(
            ):
                self.epoch = epoch
                if self.encoder.optimizer is not None and self.decoder.optimizer is not None:
                    optimizer_setup_flag = True
                    self.encoder.optimizer.zero_grad()
                    self.decoder.optimizer.zero_grad()
                    self.optimizer.zero_grad()
                else:
                    optimizer_setup_flag = False

                pred_arr = self.inference(encoded_observed_arr,
                                          decoded_observed_arr,
                                          encoded_mask_arr, decoded_mask_arr)
                loss = self.compute_loss(pred_arr, training_target_arr)
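                # If the sub-model optimizers were not ready above (deferred
                # initialization), they exist after the first forward pass;
                # zero their gradients and recompute the loss.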
                if optimizer_setup_flag is False:
                    self.encoder.optimizer.zero_grad()
                    self.decoder.optimizer.zero_grad()
                    self.optimizer.zero_grad()
                    pred_arr = self.inference(encoded_observed_arr,
                                              decoded_observed_arr,
                                              encoded_mask_arr,
                                              decoded_mask_arr)
                    loss = self.compute_loss(pred_arr, training_target_arr)

                loss.backward()
                self.optimizer.step()
                self.decoder.optimizer.step()
                self.encoder.optimizer.step()
                self.regularize()
                self.decoder.regularize()
                self.encoder.regularize()

                if (iter_n + 1) % int(
                        iteratable_data.iter_n / iteratable_data.epochs) == 0:
                    with torch.inference_mode():
                        test_pred_arr = self.inference(
                            test_encoded_observed_arr,
                            test_decoded_observed_arr, test_encoded_mask_arr,
                            test_decoded_mask_arr)

                        test_loss = self.compute_loss(test_pred_arr,
                                                      test_target_arr)

                    _loss = loss.to('cpu').detach().numpy().copy()
                    _test_loss = test_loss.to('cpu').detach().numpy().copy()

                    self.__loss_list.append((_loss, _test_loss))
                    self.logger.debug("Epochs: " + str(epoch + 1) +
                                      " Train loss: " + str(_loss) +
                                      " Test loss: " + str(_test_loss))
                    epoch += 1
                iter_n += 1

        except KeyboardInterrupt:
            self.logger.debug("Interrupt.")

        self.logger.debug("end. ")
        self.epoch = epoch

    def inference(
        self,
        encoded_observed_arr,
        decoded_observed_arr,
        encoded_mask_arr=None,
        decoded_mask_arr=None,
    ):
        '''
        Run inference on samples drawn by `IteratableData.generate_inferenced_samples()`.

        Args:
            encoded_observed_arr:   rank-3 `torch.Tensor` of observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)

            decoded_observed_arr:   rank-3 `torch.Tensor` of observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)

            encoded_mask_arr:       mask `torch.Tensor` for the encoder's attention.
                                    If `None`, a tensor of ones with shape (batch size, 1, 1, 1) is used.

            decoded_mask_arr:       mask `torch.Tensor` for the decoder's attention.
                                    If `None`, a tensor of ones with shape (batch size, 1, 1, 1) is used.

        Returns:
            `torch.Tensor` of inferenced feature points.
        '''
        return self(
            encoded_observed_arr,
            decoded_observed_arr,
            encoded_mask_arr,
            decoded_mask_arr,
        )

    def compute_loss(self, pred_arr, labeled_arr):
        '''
        Compute loss.

        Args:
            pred_arr:       `torch.Tensor`.
            labeled_arr:    `torch.Tensor`.

        Returns:
            loss.
        '''
        return self.__computable_loss(pred_arr, labeled_arr)

    def forward(
        self,
        encoded_observed_arr,
        decoded_observed_arr,
        encoded_mask_arr=None,
        decoded_mask_arr=None,
    ):
        '''
        Forward propagation.

        Args:
            encoded_observed_arr:   rank-3 `torch.Tensor` of observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)

            decoded_observed_arr:   rank-3 `torch.Tensor` of observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)

            encoded_mask_arr:       mask `torch.Tensor` for the encoder's attention.
                                    If `None`, a tensor of ones with shape (batch size, 1, 1, 1) is used.

            decoded_mask_arr:       mask `torch.Tensor` for the decoder's attention.
                                    If `None`, a tensor of ones with shape (batch size, 1, 1, 1) is used.

        Returns:
            `torch.Tensor` of inferenced feature points.
        '''
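        # If a previous `load_parameters` call was deferred (see the
        # `RuntimeError` handling there), run one forward pass with detached
        # inputs to materialize the parameters, then retry loading them.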
        if self.__loaded_filename is not None:
            loaded_filename = self.__loaded_filename
            self.__loaded_filename = None
            init_encoded_observed_arr = encoded_observed_arr.detach()
            init_decoded_observed_arr = decoded_observed_arr.detach()
            if encoded_mask_arr is not None:
                init_encoded_mask_arr = encoded_mask_arr.detach()
            else:
                init_encoded_mask_arr = None
            if decoded_mask_arr is not None:
                init_decoded_mask_arr = decoded_mask_arr.detach()
            else:
                init_decoded_mask_arr = decoded_mask_arr

            _ = self.forward(
                init_encoded_observed_arr,
                init_decoded_observed_arr,
                init_encoded_mask_arr,
                init_decoded_mask_arr,
            )
            self.load_parameters(loaded_filename, ctx=self.__loaded_ctx)
            self.__loaded_ctx = None

        if encoded_mask_arr is None:
            encoded_mask_arr = torch.ones(
                (encoded_observed_arr.shape[0], 1, 1, 1), )
            encoded_mask_arr = encoded_mask_arr.to(encoded_observed_arr.device)
        if decoded_mask_arr is None:
            decoded_mask_arr = torch.ones(
                (decoded_observed_arr.shape[0], 1, 1, 1), )
            decoded_mask_arr = decoded_mask_arr.to(decoded_observed_arr.device)

        if self.encoder_hidden_fc is not None:
            encoded_observed_arr = self.encoder_hidden_fc(encoded_observed_arr)
        if self.decoder_hidden_fc is not None:
            decoded_observed_arr = self.decoder_hidden_fc(decoded_observed_arr)

        encoded_arr = self.encoder(encoded_observed_arr, encoded_mask_arr)
        decoded_arr = self.decoder(
            decoded_observed_arr,
            encoded_arr,
            decoded_mask_arr,
            encoded_mask_arr,
        )

        self.feature_points_arr = decoded_arr

        reconstructed_arr = self.reconstructor(decoded_arr, None)

        return reconstructed_arr

    def extract_learned_dict(self):
        '''
        Extract (pre-) learned parameters.

        Returns:
            `dict` of the parameters.
        '''
        params_dict = {}
        for k in self.state_dict().keys():
            params_dict.setdefault(k, self.state_dict()[k])

        return params_dict

    def regularize(self):
        '''
        Regularization.
        '''
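        # Apply each `RegularizatableData` transformation to the current
        # parameters and write the results back into this module.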
        if len(self.__regularizatable_data_list) > 0:
            params_dict = self.extract_learned_dict()
            for regularizatable in self.__regularizatable_data_list:
                params_dict = regularizatable.regularize(params_dict)

            for k, params in params_dict.items():
                self.load_state_dict({k: params}, strict=False)

    def __rename_file(self, filename):
        '''
        Derive file names for the sub-models from `filename`, e.g.
        `model.pt` -> `model_encoder.pt`, `model_decoder.pt`, and
        `model_reconstructor.pt`.
        '''
        filename_list = filename.split(".")
        _format = filename_list[-1]
        encoder_filename = filename.replace("." + _format,
                                            "_encoder." + _format)
        decoder_filename = filename.replace("." + _format,
                                            "_decoder." + _format)
        reconstructor_filename = filename.replace("." + _format,
                                                  "_reconstructor." + _format)
        return encoder_filename, decoder_filename, reconstructor_filename

    def save_parameters(self, filename):
        '''
        Save parameters to files.

        Args:
            filename:       File name.
        '''
        encoder_filename, decoder_filename, reconstructor_filename = self.__rename_file(
            filename)

        self.encoder.epoch = self.epoch
        self.encoder.loss_arr = self.loss_arr
        self.encoder.save_parameters(encoder_filename)

        self.decoder.epoch = self.epoch
        self.decoder.loss_arr = self.loss_arr
        self.decoder.save_parameters(decoder_filename)

        self.reconstructor.epoch = self.epoch
        self.reconstructor.loss_arr = self.loss_arr
        self.reconstructor.save_parameters(reconstructor_filename)

        torch.save(
            {
                'model_state_dict': self.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'epoch': self.epoch,
                'loss': self.loss_arr,
            }, filename)

    def load_parameters(self, filename, ctx=None, strict=True):
        '''
        Load parameters from files.

        Args:
            filename:       File name.
            ctx:            Context-manager that changes the selected device.
            strict:         Whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: `True`.
        '''
        try:
            encoder_filename, decoder_filename, reconstructor_filename = self.__rename_file(
                filename)
            self.encoder.load_parameters(encoder_filename,
                                         ctx=ctx,
                                         strict=strict)
            self.decoder.load_parameters(decoder_filename,
                                         ctx=ctx,
                                         strict=strict)
            self.reconstructor.load_parameters(reconstructor_filename,
                                               ctx=ctx,
                                               strict=strict)

            checkpoint = torch.load(filename)
            self.load_state_dict(checkpoint['model_state_dict'], strict=strict)
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.epoch = checkpoint['epoch']
            self.__loss_list = checkpoint['loss'].tolist()
        except RuntimeError:
            self.__loaded_filename = filename
            self.__loaded_ctx = ctx

        if ctx is not None:
            self.to(ctx)
            self.encoder.to(ctx)
            self.decoder.to(ctx)
            self.reconstructor.to(ctx)
            self.__ctx = ctx

    def set_readonly(self, value):
        ''' setter '''
        raise TypeError("This property must be read-only.")

    def get_init_deferred_flag(self):
        ''' getter for `bool` that means whether initialization in this class will be deferred or not. '''
        return self.__init_deferred_flag

    def set_init_deferred_flag(self, value):
        ''' setter for `bool` that means whether initialization in this class will be deferred or not. '''
        self.__init_deferred_flag = value

    init_deferred_flag = property(get_init_deferred_flag,
                                  set_init_deferred_flag)

    __loss_list = []

    def get_loss_arr(self):
        ''' getter '''
        return np.array(self.__loss_list)

    def set_loss_arr(self, value):
        ''' setter '''
        raise TypeError("This property must be read-only.")

    loss_arr = property(get_loss_arr, set_loss_arr)
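

# --- Hypothetical usage sketch (an assumption, not part of the library) ------
# Shapes follow the docstrings above: rank-3 tensors of
# (batch size, the length of sequence, feature points). All other constructor
# arguments keep their defaults, and the accelbrainbase sub-models are assumed
# to build themselves from those defaults.
if __name__ == "__main__":
    controller = TransformingAutoEncoderController(
        depth_dim=100,
        hidden_dim=100,
        seq_len=5,
        ctx="cpu",
    )
    encoded_observed_arr = torch.randn(2, 5, 100)
    decoded_observed_arr = torch.randn(2, 5, 100)
    reconstructed_arr = controller.inference(
        encoded_observed_arr,
        decoded_observed_arr,
    )
    # Expected to match the reconstructor's output, i.e. roughly the encoder
    # input shape (2, 5, 100).
    print(reconstructed_arr.shape)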
# Example #2

from math import log10
import argparse
import datetime

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.optim import AdamW

# `model` and `dataloader` are the Super SloMo project's own modules.
import model
import dataloader

def train():
    global writer
    # For parsing commandline arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset_root",
        type=str,
        required=True,
        help='path to dataset folder containing train-test-validation folders')
    parser.add_argument("--checkpoint_dir",
                        type=str,
                        required=True,
                        help='path to folder for saving checkpoints')
    parser.add_argument("--checkpoint",
                        type=str,
                        help='path of checkpoint for pretrained model')
    parser.add_argument(
        "--train_continue",
        action='store_true',
        help=
        'pass this flag to resume training from the checkpoint given by `checkpoint`. Default: False.'
    )
    parser.add_argument("--epochs",
                        type=int,
                        default=200,
                        help='number of epochs to train. Default: 200.')
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=3,
                        help='batch size for training. Default: 3.')
    parser.add_argument("--validation_batch_size",
                        type=int,
                        default=6,
                        help='batch size for validation. Default: 6.')
    parser.add_argument("--init_learning_rate",
                        type=float,
                        default=0.0001,
                        help='set initial learning rate. Default: 0.0001.')
    parser.add_argument(
        "--milestones",
        type=list,
        default=[25, 50],
        help=
        'UNUSED NOW: Set to epoch values where you want to decrease learning rate by a factor of 0.1. Default: [25, 50].'
    )
    parser.add_argument(
        "--progress_iter",
        type=int,
        default=200,
        help=
        'frequency of reporting progress and validation. N: after every N iterations. Default: 200.'
    )
    parser.add_argument(
        "--checkpoint_epoch",
        type=int,
        default=5,
        help=
        'checkpoint saving frequency. N: after every N epochs. Each checkpoint is roughly of size 151 MB. Default: 5.'
    )
    args = parser.parse_args()

    ## [TensorboardX](https://github.com/lanpa/tensorboardX)
    ### For visualizing loss and interpolated frames

    ### Initialize flow computation and arbitrary-time flow interpolation CNNs.

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(device)
    flowComp = model.UNet(6, 4)
    flowComp.to(device)
    ArbTimeFlowIntrp = model.UNet(20, 5)
    ArbTimeFlowIntrp.to(device)

    ### Initialize backward warpers for train and validation datasets

    train_W_dim = 352
    train_H_dim = 352

    trainFlowBackWarp = model.backWarp(train_W_dim, train_H_dim, device)
    trainFlowBackWarp = trainFlowBackWarp.to(device)
    validationFlowBackWarp = model.backWarp(train_W_dim * 2, train_H_dim,
                                            device)
    validationFlowBackWarp = validationFlowBackWarp.to(device)

    ### Load Datasets

    # Channel wise mean calculated on custom training dataset
    # mean = [0.43702903766008444, 0.43715053433990597, 0.40436416782660994]
    mean = [0.5] * 3
    std = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean, std=std)
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    trainset = dataloader.SuperSloMo(root=args.dataset_root + '/train',
                                     randomCropSize=(train_W_dim, train_H_dim),
                                     transform=transform,
                                     train=True)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.train_batch_size,
                                              shuffle=True,
                                              num_workers=2,
                                              pin_memory=True)

    validationset = dataloader.SuperSloMo(
        root=args.dataset_root + '/validation',
        transform=transform,
        randomCropSize=(2 * train_W_dim, train_H_dim),
        train=False)
    validationloader = torch.utils.data.DataLoader(
        validationset,
        batch_size=args.validation_batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True)

    print(trainset, validationset)

    ### Create transform to display image from tensor

    negmean = [x * -1 for x in mean]
    revNormalize = transforms.Normalize(mean=negmean, std=std)
    TP = transforms.Compose([revNormalize, transforms.ToPILImage()])

    ### Utils

    def get_lr(optimizer):
        for param_group in optimizer.param_groups:
            return param_group['lr']

    ### Loss and Optimizer

    L1_lossFn = nn.L1Loss()
    MSE_LossFn = nn.MSELoss()

    if args.train_continue:
        dict1 = torch.load(args.checkpoint)
        last_epoch = dict1['epoch'] * len(trainloader)
    else:
        last_epoch = -1

    params = list(ArbTimeFlowIntrp.parameters()) + list(flowComp.parameters())

    optimizer = AdamW(params, lr=args.init_learning_rate, amsgrad=True)
    # optimizer = optim.SGD(params, lr=args.init_learning_rate, momentum=0.9, nesterov=True)

    # Scheduler to decrease the learning rate by a factor of 10 once the loss plateaus.
    # Suggested patience value:
    #   patience = (number of items in train dataset / train_batch_size) * (number of epochs of patience)
    # The scheduler documentation speaks of epochs, but what is counted here is
    # scheduler steps, i.e. training iterations, so one "epoch" of patience
    # corresponds to roughly len(trainloader) steps (assuming a rough dataset count).
    # If the model plateaus quickly, reduce the patience accordingly.
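    # For example, with 30,000 training samples and train_batch_size=3,
    # len(trainloader) is 10,000 iterations per epoch, so
    # patience=len(trainloader) * 3 waits roughly three epochs without
    # improvement before lowering the rate (the sample count is illustrative).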

    # scheduler = optim.lr_scheduler.CyclicLR(optimizer,
    #                                         base_lr=1e-8,
    #                                         max_lr=9.0e-3,
    #                                         step_size_up=3500,
    #                                         mode='triangular2',
    #                                         cycle_momentum=False,
    #                                         last_epoch=last_epoch)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=len(trainloader) * 3,
        cooldown=len(trainloader) * 2,
        verbose=True,
        min_lr=1e-8)

    # ReduceLROnPlateau was chosen to make training more adaptive: this model
    # seems to converge or plateau quickly, with rapid swings over time, so
    # reacting to stagnation directly is more useful than waiting for a fixed
    # milestone.

    ### Initializing VGG16 model for perceptual loss

    vgg16 = torchvision.models.vgg16(pretrained=True)
    vgg16_conv_4_3 = nn.Sequential(*list(vgg16.children())[0][:22])
    vgg16_conv_4_3.to(device)

    for param in vgg16_conv_4_3.parameters():
        param.requires_grad = False

    # Validation function

    def validate():
        # For details see training.
        psnr = 0
        tloss = 0
        flag = 1
        with torch.no_grad():
            for validationIndex, (validationData,
                                  validationFrameIndex) in enumerate(
                                      validationloader, 0):
                frame0, frameT, frame1 = validationData

                I0 = frame0.to(device)
                I1 = frame1.to(device)
                IFrame = frameT.to(device)

                torch.cuda.empty_cache()
                flowOut = flowComp(torch.cat((I0, I1), dim=1))
                F_0_1 = flowOut[:, :2, :, :]
                F_1_0 = flowOut[:, 2:, :, :]

                fCoeff = model.getFlowCoeff(validationFrameIndex, device)
                torch.cuda.empty_cache()
                F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                g_I0_F_t_0 = validationFlowBackWarp(I0, F_t_0)
                g_I1_F_t_1 = validationFlowBackWarp(I1, F_t_1)
                torch.cuda.empty_cache()
                intrpOut = ArbTimeFlowIntrp(
                    torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1,
                               g_I0_F_t_0),
                              dim=1))

                F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
                V_t_1 = 1 - V_t_0
                # torch.cuda.empty_cache()
                g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f)
                g_I1_F_t_1_f = validationFlowBackWarp(I1, F_t_1_f)

                wCoeff = model.getWarpCoeff(validationFrameIndex, device)
                torch.cuda.empty_cache()
                Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 *
                        g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                # For tensorboard
                if (flag):
                    retImg = torchvision.utils.make_grid([
                        revNormalize(frame0[0]),
                        revNormalize(frameT[0]),
                        revNormalize(Ft_p.cpu()[0]),
                        revNormalize(frame1[0])
                    ],
                                                         padding=10)
                    flag = 0

                # loss
                recnLoss = L1_lossFn(Ft_p, IFrame)
                # torch.cuda.empty_cache()
                prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p),
                                      vgg16_conv_4_3(IFrame))

                warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(
                    g_I1_F_t_1, IFrame) + L1_lossFn(
                        validationFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(
                            validationFlowBackWarp(I1, F_0_1), I0)
                torch.cuda.empty_cache()
                loss_smooth_1_0 = torch.mean(
                    torch.abs(F_1_0[:, :, :, :-1] -
                              F_1_0[:, :, :, 1:])) + torch.mean(
                                  torch.abs(F_1_0[:, :, :-1, :] -
                                            F_1_0[:, :, 1:, :]))
                loss_smooth_0_1 = torch.mean(
                    torch.abs(F_0_1[:, :, :, :-1] -
                              F_0_1[:, :, :, 1:])) + torch.mean(
                                  torch.abs(F_0_1[:, :, :-1, :] -
                                            F_0_1[:, :, 1:, :]))
                loss_smooth = loss_smooth_1_0 + loss_smooth_0_1

                # torch.cuda.empty_cache()
                loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth
                tloss += loss.item()

                # PSNR for intensities in [0, 1]: 10 * log10(1 / MSE).
                MSE_val = MSE_LossFn(Ft_p, IFrame)
                psnr += (10 * log10(1 / MSE_val.item()))
                torch.cuda.empty_cache()

        return (psnr / len(validationloader)), (tloss /
                                                len(validationloader)), retImg

    ### Initialization

    if args.train_continue:
        ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])
        flowComp.load_state_dict(dict1['state_dictFC'])

        optimizer.load_state_dict(dict1.get('state_optimizer', {}))
        scheduler.load_state_dict(dict1.get('state_scheduler', {}))

        for param_group in optimizer.param_groups:
            param_group['lr'] = dict1.get('learningRate',
                                          args.init_learning_rate)

    else:
        dict1 = {'loss': [], 'valLoss': [], 'valPSNR': [], 'epoch': -1}

    ### Training

    import time

    start = time.time()
    cLoss = dict1['loss']
    valLoss = dict1['valLoss']
    valPSNR = dict1['valPSNR']
    checkpoint_counter = 0

    ### Main training loop

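    # A single optimizer step before the loop; this appears to be a workaround
    # so that the scheduler's first step is not issued before any
    # `optimizer.step()` call when resuming from a checkpoint.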
    optimizer.step()

    for epoch in range(dict1['epoch'] + 1, args.epochs):
        print("Epoch: ", epoch)

        # Append and reset
        cLoss.append([])
        valLoss.append([])
        valPSNR.append([])
        iLoss = 0

        for trainIndex, (trainData,
                         trainFrameIndex) in enumerate(trainloader, 0):

            ## Getting the input and the target from the training set
            frame0, frameT, frame1 = trainData

            I0 = frame0.to(device)
            I1 = frame1.to(device)
            IFrame = frameT.to(device)
            optimizer.zero_grad()
            # torch.cuda.empty_cache()
            # Calculate flow between reference frames I0 and I1
            flowOut = flowComp(torch.cat((I0, I1), dim=1))

            # Extracting flows between I0 and I1 - F_0_1 and F_1_0
            F_0_1 = flowOut[:, :2, :, :]
            F_1_0 = flowOut[:, 2:, :, :]

            fCoeff = model.getFlowCoeff(trainFrameIndex, device)

            # Calculate intermediate flows
            F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
            F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

            # Get intermediate frames from the intermediate flows
            g_I0_F_t_0 = trainFlowBackWarp(I0, F_t_0)
            g_I1_F_t_1 = trainFlowBackWarp(I1, F_t_1)
            torch.cuda.empty_cache()
            # Calculate optical flow residuals and visibility maps
            intrpOut = ArbTimeFlowIntrp(
                torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1,
                           g_I0_F_t_0),
                          dim=1))

            # Extract optical flow residuals and visibility maps
            F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
            F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
            V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
            V_t_1 = 1 - V_t_0
            # torch.cuda.empty_cache()
            # Get intermediate frames from the intermediate flows
            g_I0_F_t_0_f = trainFlowBackWarp(I0, F_t_0_f)
            g_I1_F_t_1_f = trainFlowBackWarp(I1, F_t_1_f)
            # torch.cuda.empty_cache()
            wCoeff = model.getWarpCoeff(trainFrameIndex, device)
            torch.cuda.empty_cache()
            # Calculate final intermediate frame
            Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 *
                    g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

            # Loss
            recnLoss = L1_lossFn(Ft_p, IFrame)
            # torch.cuda.empty_cache()

            prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame))
            # torch.cuda.empty_cache()
            warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(
                g_I1_F_t_1, IFrame) + L1_lossFn(
                    trainFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(
                        trainFlowBackWarp(I1, F_0_1), I0)

            loss_smooth_1_0 = torch.mean(
                torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:])
            ) + torch.mean(torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :]))
            loss_smooth_0_1 = torch.mean(
                torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:])
            ) + torch.mean(torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :]))
            loss_smooth = loss_smooth_1_0 + loss_smooth_0_1
            # torch.cuda.empty_cache()
            # Total Loss - Coefficients 204 and 102 are used instead of 0.8 and 0.4
            # since the loss in paper is calculated for input pixels in range 0-255
            # and the input to our network is in range 0-1
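            # (0.8 * 255 = 204 and 0.4 * 255 = 102.)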
            loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth

            # Backpropagate

            loss.backward()
            optimizer.step()
            scheduler.step(loss.item())

            iLoss += loss.item()
            torch.cuda.empty_cache()
            # Validation and progress every `args.progress_iter` iterations
            if ((trainIndex % args.progress_iter) == args.progress_iter - 1):
                # Increment scheduler count
                scheduler.step(iLoss / args.progress_iter)

                end = time.time()

                psnr, vLoss, valImg = validate()
                optimizer.zero_grad()
                # torch.cuda.empty_cache()
                valPSNR[epoch].append(psnr)
                valLoss[epoch].append(vLoss)

                # Tensorboard
                itr = trainIndex + epoch * (len(trainloader))

                writer.add_scalars(
                    'Loss', {
                        'trainLoss': iLoss / args.progress_iter,
                        'validationLoss': vLoss
                    }, itr)
                writer.add_scalar('PSNR', psnr, itr)

                writer.add_image('Validation', valImg, itr)
                #####

                endVal = time.time()

                print(
                    " Loss: %0.6f  Iterations: %4d/%4d  TrainExecTime: %0.1f  ValLoss:%0.6f  ValPSNR: %0.4f  ValEvalTime: %0.2f LearningRate: %.1e"
                    % (iLoss / args.progress_iter, trainIndex,
                       len(trainloader), end - start, vLoss, psnr,
                       endVal - end, get_lr(optimizer)))

                # torch.cuda.empty_cache()
                cLoss[epoch].append(iLoss / args.progress_iter)
                iLoss = 0
                start = time.time()

        # Create checkpoint after every `args.checkpoint_epoch` epochs
        if (epoch % args.checkpoint_epoch) == args.checkpoint_epoch - 1:
            dict1 = {
                'Detail': "End to end Super SloMo.",
                'epoch': epoch,
                'timestamp': datetime.datetime.now(),
                'trainBatchSz': args.train_batch_size,
                'validationBatchSz': args.validation_batch_size,
                'learningRate': get_lr(optimizer),
                'loss': cLoss,
                'valLoss': valLoss,
                'valPSNR': valPSNR,
                'state_dictFC': flowComp.state_dict(),
                'state_dictAT': ArbTimeFlowIntrp.state_dict(),
                'state_optimizer': optimizer.state_dict(),
                'state_scheduler': scheduler.state_dict()
            }
            torch.save(
                dict1, args.checkpoint_dir + "/SuperSloMo" +
                str(checkpoint_counter) + ".ckpt")
            checkpoint_counter += 1
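
if __name__ == "__main__":
    # `train()` expects a module-level TensorboardX `writer` (see the
    # `global writer` statement above); a minimal entry point might look like
    # this. The log directory name is an assumption.
    from tensorboardX import SummaryWriter
    writer = SummaryWriter('log')
    train()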
import numpy as np
import torch
from torch.optim import AdamW

# `ObservableData` and `IteratableData` are assumed to be importable from the
# surrounding accelbrainbase library.


class TransformerModel(ObservableData):
    '''
    The abstract class of Transformer.

    References:
        - Bahdanau, D., Cho, K., & Bengio, Y. (2014). Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473.
        - Floridi, L., & Chiriatti, M. (2020). GPT-3: Its nature, scope, limits, and consequences. Minds and Machines, 30(4), 681-694.
        - Miller, A., Fisch, A., Dodge, J., Karimi, A. H., Bordes, A., & Weston, J. (2016). Key-value memory networks for directly reading documents. arXiv preprint arXiv:1606.03126.
        - Radford, A., Narasimhan, K., Salimans, T., & Sutskever, I. (2018) Improving Language Understanding by Generative Pre-Training. OpenAI (URL: https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
        - Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., & Sutskever, I. (2019). Language models are unsupervised multitask learners. OpenAI blog, 1(8), 9.
        - Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., & Polosukhin, I. (2017). Attention is all you need. arXiv preprint arXiv:1706.03762.

    '''

    __optimizer = None

    def get_optimizer(self):
        return self.__optimizer
    
    def set_optimizer(self, value):
        self.__optimizer = value

    optimizer = property(get_optimizer, set_optimizer)

    __embedding_flag = True

    def get_embedding_flag(self):
        ''' getter '''
        return self.__embedding_flag
    
    def set_embedding_flag(self, value):
        ''' setter '''
        self.__embedding_flag = value

    embedding_flag = property(get_embedding_flag, set_embedding_flag)

    __embedding_weignt = 1.0

    def get_embedding_weignt(self):
        ''' getter '''
        return self.__embedding_weignt
    
    def set_embedding_weignt(self, value):
        ''' setter '''
        self.__embedding_weignt = value
    
    embedding_weignt = property(get_embedding_weignt, set_embedding_weignt)

    def learn(self, iteratable_data):
        '''
        Learn samples drawn by `IteratableData.generate_learned_samples()`.

        Args:
            iteratable_data:     is-a `IteratableData`.
        '''
        if isinstance(iteratable_data, IteratableData) is False:
            raise TypeError("The type of `iteratable_data` must be `IteratableData`.")

        self.__loss_list = []
        learning_rate = self.learning_rate

        pre_batch_observed_arr = None
        pre_test_batch_observed_arr = None
        try:
            epoch = 0
            iter_n = 0
            for batch_observed_arr, batch_target_arr, test_batch_observed_arr, test_batch_target_arr in iteratable_data.generate_learned_samples():
                self.epoch = epoch
                if self.optimizer is not None:
                    self.optimizer.zero_grad()
                    optimizer_setup_flag = True
                else:
                    optimizer_setup_flag = False

                # Self-Attention.
                if len(batch_observed_arr.shape) == 2:
                    batch_observed_arr = torch.unsqueeze(batch_observed_arr, dim=1)

                pred_arr = self.inference(batch_observed_arr, batch_observed_arr)
                loss = self.compute_loss(
                    pred_arr,
                    batch_target_arr
                )
                if optimizer_setup_flag is False:
                    # After initialization, restart.
                    self.optimizer.zero_grad()
                    pred_arr = self.inference(batch_observed_arr, batch_observed_arr)
                    loss = self.compute_loss(
                        pred_arr,
                        batch_target_arr
                    )

                loss.backward()
                self.optimizer.step()
                self.regularize()

                if (iter_n+1) % int(iteratable_data.iter_n / iteratable_data.epochs) == 0:
                    with torch.inference_mode():
                        if len(test_batch_observed_arr.shape) == 2:
                            test_batch_observed_arr = torch.unsqueeze(test_batch_observed_arr, dim=1)

                        test_pred_arr = self.inference(test_batch_observed_arr, test_batch_observed_arr)

                        test_loss = self.compute_loss(
                            test_pred_arr,
                            test_batch_target_arr
                        )

                    _loss = loss.to('cpu').detach().numpy().copy()
                    _test_loss = test_loss.to('cpu').detach().numpy().copy()

                    self.__loss_list.append((_loss, _test_loss))
                    self.logger.debug("Epochs: " + str(epoch + 1) + " Train loss: " + str(_loss) + " Test loss: " + str(_test_loss))
                    epoch += 1
                iter_n += 1

        except KeyboardInterrupt:
            self.logger.debug("Interrupt.")

        self.logger.debug("end. ")
        self.epoch = epoch

    def compute_loss(self, pred_arr, labeled_arr):
        '''
        Compute loss.

        Args:
            pred_arr:       `torch.Tensor`.
            labeled_arr:    `torch.Tensor`.

        Returns:
            loss.
        '''
        return self.__computable_loss(pred_arr, labeled_arr)

    def regularize(self):
        '''
        Regularization.
        '''
        if len(self.regularizatable_data_list) > 0:
            params_dict = self.extract_learned_dict()
            for regularizatable in self.regularizatable_data_list:
                params_dict = regularizatable.regularize(params_dict)

            for k, params in params_dict.items():
                self.load_state_dict({k: params}, strict=False)

    def extract_learned_dict(self):
        '''
        Extract (pre-) learned parameters.

        Returns:
            `dict` of the parameters.
        '''
        params_dict = {}
        for k in self.state_dict().keys():
            params_dict.setdefault(k, self.state_dict()[k])

        return params_dict

    def embedding(self, observed_arr):
        '''
        Embedding. By default, this method applies positional encoding.

        Args:
            observed_arr:       `torch.Tensor` of observed data points.

        Returns:
            `torch.Tensor` of embedded data points.
        '''
        batch_size, seq_len, depth_dim = observed_arr.shape
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.depth_dim = depth_dim

        if self.embedding_flag is False:
            return observed_arr

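        # Sinusoidal positional encoding (Vaswani et al., 2017):
        #   PE(pos, 2i)   = sin(pos / 10000^(2i / depth_dim))
        #   PE(pos, 2i+1) = cos(pos / 10000^(2i / depth_dim))
        # `depth_arr` below holds the 10000^(2i / depth_dim) term, and
        # `phase_arr` adds a pi/2 phase to odd dimensions, which turns the
        # sine into a cosine.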
        arr = torch.from_numpy(np.arange(depth_dim))
        arr = arr.to(observed_arr.device)
        depth_arr = torch.tile(
            torch.unsqueeze(
                (
                    arr / 2
                ).to(torch.int32) * 2,
                0
            ), 
            (seq_len, 1)
        )

        depth_arr = depth_arr / depth_dim
        depth_arr = torch.pow(10000.0, depth_arr).to(torch.float32)

        arr = torch.from_numpy(np.arange(depth_dim))
        arr = arr.to(observed_arr.device)
        phase_arr = torch.tile(
            torch.unsqueeze(
                (
                    arr % 2
                ) * np.pi / 2,
                0
            ), 
            (seq_len, 1)
        )
        arr = torch.from_numpy(np.arange(seq_len))
        arr = arr.to(observed_arr.device)
        positional_arr = torch.tile(
            torch.unsqueeze(
                arr, 
                1
            ), 
            (1, depth_dim)
        )

        sin_arr = torch.sin(positional_arr / depth_arr + phase_arr)

        positional_encoded_arr = torch.tile(
            torch.unsqueeze(sin_arr, 0), 
            (batch_size, 1, 1)
        )

        result_arr = observed_arr + (positional_encoded_arr * self.embedding_weignt)
        return result_arr

    def save_parameters(self, filename):
        '''
        Save parameters to files.

        Args:
            filename:       File name.
        '''
        torch.save(
            {
                'epoch': self.epoch,
                'model_state_dict': self.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': self.loss_arr,
            }, 
            filename
        )

    def load_parameters(self, filename, ctx=None, strict=True):
        '''
        Load parameters from files.

        Args:
            filename:       File name.
            ctx:            Context-manager that changes the selected device.
            strict:         Whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: `True`.
        '''
        checkpoint = torch.load(filename)
        self.load_state_dict(checkpoint['model_state_dict'], strict=strict)

        try:
            if self.optimizer is None:
                if self.optimizer_f is not None:
                    self.optimizer = self.optimizer_f(
                        self.parameters()
                    )
                else:
                    self.optimizer = AdamW(
                        self.parameters(),
                        lr=self.learning_rate,
                        weight_decay=self.weight_decay
                    )

            self.optimizer.load_state_dict(
                checkpoint['optimizer_state_dict']
            )
        except ValueError as e:
            self.logger.debug(e)
            self.logger.debug("The state of the optimizer in `TransformerModel` was not updated.")

        self.epoch = checkpoint['epoch']
        self.__loss_list = checkpoint['loss'].tolist()
        if ctx is not None:
            self.to(ctx)
            self.__ctx = ctx

    __epoch = 0

    def get_epoch(self):
        ''' getter for epoch. '''
        return self.__epoch

    def set_epoch(self, value):
        ''' setter for epoch. '''
        self.__epoch = value

    epoch = property(get_epoch, set_epoch)

    __loss_arr = np.array([])

    def get_loss_arr(self):
        ''' getter for losses. '''
        return self.__loss_arr

    def set_loss_arr(self, value):
        self.__loss_arr = value

    loss_arr = property(get_loss_arr, set_loss_arr)

    # `bool` that means whether initialization in this class will be deferred or not.
    __init_deferred_flag = False

    def get_init_deferred_flag(self):
        ''' getter for `bool` that means whether initialization in this class will be deferred or not. '''
        return self.__init_deferred_flag

    def set_init_deferred_flag(self, value):
        ''' setter for `bool` that means whether initialization in this class will be deferred or not. '''
        self.__init_deferred_flag = value

    init_deferred_flag = property(get_init_deferred_flag, set_init_deferred_flag)