Example #1
class ShuffleSelfAttentionModel(Model):
    def __init__(self,
                 save_path,
                 log_path,
                 n_depth,
                 d_features,
                 d_classifier,
                 d_output,
                 threshold=None,
                 stack='ShuffleSelfAttention',
                 expansion_layer='ChannelWiseConvExpansion',
                 mode='1d',
                 optimizer=None,
                 **kwargs):
        '''**kwargs: n_layers, n_head, n_channel, n_vchannel, dropout, use_bottleneck, d_bottleneck'''
        '''
            Arguments:
                mode:   1d:         1d output
                        2d:         2d output
                        residual:   residual output
                        dense:      dense net
        
        '''

        super().__init__(save_path, log_path)
        self.d_output = d_output
        self.threshold = threshold

        # ----------------------------- Model ------------------------------ #
        stack_dict = {
            'ReactionAttention': ReactionAttentionStack,
            'SelfAttention': SelfAttentionStack,
            'Alternate': AlternateStack,
            'Parallel': ParallelStack,
            'ShuffleSelfAttention': ShuffleSelfAttentionStack,
            'ShuffleSelfAttentionStackV2': ShuffleSelfAttentionStackV2,
        }
        expansion_dict = {
            'LinearExpansion': LinearExpansion,
            'ReduceParamLinearExpansion': ReduceParamLinearExpansion,
            'ConvExpansion': ConvExpansion,
            'LinearConvExpansion': LinearConvExpansion,
            'ShuffleConvExpansion': ShuffleConvExpansion,
            'ChannelWiseConvExpansion': ChannelWiseConvExpansion,
        }

        self.model = stack_dict[stack](expansion_dict[expansion_layer],
                                       n_depth=n_depth,
                                       d_features=d_features,
                                       mode=mode,
                                       **kwargs)

        # --------------------------- Classifier --------------------------- #
        if mode == '1d':
            self.classifier = LinearClassifier(d_features, d_classifier,
                                               d_output)
        elif mode == '2d':
            self.classifier = LinearClassifier(n_depth * d_features,
                                               d_classifier, d_output)
        else:
            self.classifier = None

        # ------------------------------ CUDA ------------------------------ #
        # If GPU available, move the graph to GPU(s)

        self.CUDA_AVAILABLE = self.check_cuda()
        if self.CUDA_AVAILABLE:
            # self.model.cuda()
            # self.classifier.cuda()

            device_ids = list(range(torch.cuda.device_count()))
            self.model = nn.DataParallel(self.model, device_ids)
            self.classifier = nn.DataParallel(self.classifier, device_ids)
            self.model.to('cuda')
            self.classifier.to('cuda')
            assert (next(self.model.parameters()).is_cuda)
            assert (next(self.classifier.parameters()).is_cuda)
            pass

        else:
            print('CUDA not found or not enabled, using CPU instead')

        # ---------------------------- Optimizer --------------------------- #
        self.parameters = list(self.model.parameters()) + list(
            self.classifier.parameters())
        if optimizer is None:
            self.optimizer = AdamW(self.parameters,
                                   lr=0.002,
                                   betas=(0.9, 0.999),
                                   weight_decay=0.001)

        # ------------------------ training control ------------------------ #
        self.controller = TrainingControl(max_step=100000,
                                          evaluate_every_nstep=100,
                                          print_every_nstep=10)
        self.early_stopping = EarlyStopping(patience=50)

        # --------------------- logging and tensorboard -------------------- #
        self.set_logger()
        self.set_summary_writer()
        # ---------------------------- END INIT ---------------------------- #

    def checkpoint(self, step):
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'classifier_state_dict': self.classifier.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'global_step': step
        }
        return checkpoint
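
    # A restore sketch for the dict returned above (illustrative; `path` is a placeholder). When CUDA is
    # available, `self.model` and `self.classifier` are nn.DataParallel wrappers, so the saved keys carry
    # a 'module.' prefix and should be loaded back into the same wrapped modules:
    #     checkpoint = torch.load(path)
    #     self.model.load_state_dict(checkpoint['model_state_dict'])
    #     self.classifier.load_state_dict(checkpoint['classifier_state_dict'])
    #     self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])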

    def train_epoch(self, train_dataloader, eval_dataloader, device, smoothing,
                    earlystop):
        ''' Epoch operation in training phase'''

        if device == 'cuda':
            assert self.CUDA_AVAILABLE
        # Set model and classifier training mode
        self.model.train()
        self.classifier.train()

        total_loss = 0
        batch_counter = 0

        # update param per batch
        for batch in tqdm(train_dataloader,
                          mininterval=1,
                          desc='  - (Training)   ',
                          leave=False):  # train_dataloader should be an iterable

            # get data from dataloader

            feature_1, feature_2, y = parse_data(batch, device)

            batch_size = len(feature_1)

            # forward
            self.optimizer.zero_grad()
            logits, attn = self.model(feature_1, feature_2)
            logits = logits.view(batch_size, -1)
            logits = self.classifier(logits)

            # Treat a single output (d_output == 1) as a regression problem
            if self.d_output == 1:
                pred = logits.sigmoid()
                loss = mse_loss(pred, y)

            else:
                pred = logits
                loss = cross_entropy_loss(pred, y, smoothing=smoothing)

            # calculate gradients
            loss.backward()

            # update parameters
            self.optimizer.step()

            # get metrics for logging
            acc = accuracy(pred, y, threshold=self.threshold)
            precision, recall, precision_avg, recall_avg = precision_recall(
                pred, y, self.d_output, threshold=self.threshold)
            total_loss += loss.item()
            batch_counter += 1

            # training control
            state_dict = self.controller(batch_counter)

            if state_dict['step_to_print']:
                self.train_logger.info(
                    '[TRAINING]   - step: %5d, loss: %3.4f, acc: %1.4f, pre: %1.4f, rec: %1.4f'
                    % (state_dict['step'], loss, acc, precision[1], recall[1]))
                self.summary_writer.add_scalar('loss/train', loss,
                                               state_dict['step'])
                self.summary_writer.add_scalar('acc/train', acc,
                                               state_dict['step'])
                self.summary_writer.add_scalar('precision/train', precision[1],
                                               state_dict['step'])
                self.summary_writer.add_scalar('recall/train', recall[1],
                                               state_dict['step'])

            if state_dict['step_to_evaluate']:
                stop = self.val_epoch(eval_dataloader, device,
                                      state_dict['step'])
                state_dict['step_to_stop'] = stop

                if earlystop and stop:
                    break

            if self.controller.current_step == self.controller.max_step:
                state_dict['step_to_stop'] = True
                break

        return state_dict

    def val_epoch(self, dataloader, device, step=0, plot=False):
        ''' Epoch operation in evaluation phase '''
        if device == 'cuda':
            assert self.CUDA_AVAILABLE

        # Set model and classifier training mode
        self.model.eval()
        self.classifier.eval()

        # use evaluator to calculate the average performance
        evaluator = Evaluator()

        pred_list = []
        real_list = []

        with torch.no_grad():

            for batch in tqdm(
                    dataloader,
                    mininterval=5,
                    desc='  - (Evaluation)   ',
                    leave=False):  # dataloader should be an iterable

                # get data from dataloader
                feature_1, feature_2, y = parse_data(batch, device)

                batch_size = len(feature_1)

                # get logits
                logits, attn = self.model(feature_1, feature_2)
                logits = logits.view(batch_size, -1)
                logits = self.classifier(logits)

                if self.d_output == 1:
                    pred = logits.sigmoid()
                    loss = mse_loss(pred, y)

                else:
                    pred = logits
                    loss = cross_entropy_loss(pred, y, smoothing=False)

                acc = accuracy(pred, y, threshold=self.threshold)
                precision, recall, _, _ = precision_recall(
                    pred, y, self.d_output, threshold=self.threshold)

                # feed the metrics in the evaluator
                evaluator(loss.item(), acc.item(), precision[1].item(),
                          recall[1].item())
                # Append the results to the predict / real lists for drawing an ROC or PR curve.
                if plot:
                    pred_list += pred.tolist()
                    real_list += y.tolist()

            if plot:
                area, precisions, recalls, thresholds = pr(
                    pred_list, real_list)
                plot_pr_curve(recalls, precisions, auc=area)

            # get evaluation results from the evaluator
            loss_avg, acc_avg, pre_avg, rec_avg = evaluator.avg_results()

            self.eval_logger.info(
                '[EVALUATION] - step: %5d, loss: %3.4f, acc: %1.4f, pre: %1.4f, rec: %1.4f'
                % (step, loss_avg, acc_avg, pre_avg, rec_avg))
            self.summary_writer.add_scalar('loss/eval', loss_avg, step)
            self.summary_writer.add_scalar('acc/eval', acc_avg, step)
            self.summary_writer.add_scalar('precision/eval', pre_avg, step)
            self.summary_writer.add_scalar('recall/eval', rec_avg, step)

            state_dict = self.early_stopping(loss_avg)

            if state_dict['save']:
                checkpoint = self.checkpoint(step)
                self.save_model(
                    checkpoint,
                    self.save_path + '-step-%d_loss-%.5f' % (step, loss_avg))

            return state_dict['break']

    def train(self,
              max_epoch,
              train_dataloader,
              eval_dataloader,
              device,
              smoothing=False,
              earlystop=False,
              save_mode='best'):

        assert save_mode in ['all', 'best']
        # train for n epochs
        for epoch_i in range(max_epoch):
            print('[ Epoch', epoch_i, ']')
            # set current epoch
            self.controller.set_epoch(epoch_i + 1)
            # train for one epoch
            state_dict = self.train_epoch(train_dataloader, eval_dataloader,
                                          device, smoothing, earlystop)

            # if state_dict['step_to_stop']:
            #     break

        checkpoint = self.checkpoint(state_dict['step'])

        self.save_model(checkpoint,
                        self.save_path + '-step-%d' % state_dict['step'])

        self.train_logger.info(
            '[INFO]: Finished training after %d epoch(s) and %d batches, %d training steps in total.'
            % (state_dict['epoch'] - 1, state_dict['batch'], state_dict['step']))

    def get_predictions(self,
                        data_loader,
                        device,
                        max_batches=None,
                        activation=None):

        pred_list = []
        real_list = []

        self.model.eval()
        self.classifier.eval()

        batch_counter = 0

        with torch.no_grad():
            for batch in tqdm(data_loader,
                              desc='  - (Testing)   ',
                              leave=False):

                feature_1, feature_2, y = parse_data(batch, device)

                # get logits
                logits, attn = self.model(feature_1, feature_2)
                logits = logits.view(logits.shape[0], -1)
                logits = self.classifier(logits)

                # Whether to apply activation function
                if activation is not None:
                    pred = activation(logits)
                else:
                    pred = logits.softmax(dim=-1)
                pred_list += pred.tolist()
                real_list += y.tolist()

                if max_batches is not None:
                    batch_counter += 1
                    if batch_counter >= max_batches:
                        break

        return pred_list, real_list
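
# Usage sketch (illustrative only; all values, paths, and dataloaders below are assumptions chosen to
# match the constructor, `train`, and `get_predictions` signatures above):
#
#     model = ShuffleSelfAttentionModel(
#         save_path='checkpoints/shuffle-attn', log_path='logs/shuffle-attn',
#         n_depth=8, d_features=128, d_classifier=256, d_output=2,
#         n_layers=4, n_head=4, n_channel=8, n_vchannel=8, dropout=0.1)
#     model.train(max_epoch=10, train_dataloader=train_loader, eval_dataloader=eval_loader,
#                 device='cuda', smoothing=True, earlystop=True)
#     preds, targets = model.get_predictions(test_loader, device='cuda')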
class TransformingAutoEncoderController(nn.Module, ControllableModel):
    '''
    Transforming Auto-Encoder.

    References:
        - Bahdanau, D., Cho, K., & Bengio, Y. (2014). Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473.
        - Floridi, L., & Chiriatti, M. (2020). GPT-3: Its nature, scope, limits, and consequences. Minds and Machines, 30(4), 681-694.
        - Miller, A., Fisch, A., Dodge, J., Karimi, A. H., Bordes, A., & Weston, J. (2016). Key-value memory networks for directly reading documents. arXiv preprint arXiv:1606.03126.
        - Radford, A., Narasimhan, K., Salimans, T., & Sutskever, I. (2018) Improving Language Understanding by Generative Pre-Training. OpenAI (URL: https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
        - Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., & Sutskever, I. (2019). Language models are unsupervised multitask learners. OpenAI blog, 1(8), 9.
        - Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., & Polosukhin, I. (2017). Attention is all you need. arXiv preprint arXiv:1706.03762.

    '''
    __loaded_filename = None
    __loaded_ctx = None

    # `bool` that means initialization in this class will be deferred or not.
    __init_deferred_flag = False

    def __init__(
        self,
        computable_loss=None,
        optimizer_f=None,
        encoder=None,
        decoder=None,
        reconstructor=None,
        layer_n=3,
        head_n=3,
        seq_len=5,
        depth_dim=100,
        hidden_dim=100,
        self_attention_activation_list=[],
        multi_head_attention_activation_list=[],
        fc_activation_list=[],
        learning_rate=1e-05,
        weight_decay=0.01,
        dropout_rate=0.5,
        ctx="cpu",
        regularizatable_data_list=[],
        not_init_flag=False,
    ):
        '''
        Init.

        Args:
            computable_loss:                is-a `ComputableLoss` or a `torch.nn` loss function. Defaults to `nn.CrossEntropyLoss`.
            optimizer_f:                    `function` that builds an optimizer from this model's parameters. If `None`, `AdamW` is used.
            encoder:                        is-a `TransformerModel` used as the encoder.
            decoder:                        is-a `TransformerModel` used as the decoder.
            reconstructor:                  is-a `TransformerModel` used as the reconstructor.
            layer_n:                        `int` of the number of layers.
            head_n:                         `int` of the number of heads for multi-head attention model.
            seq_len:                        `int` of the length of sequences.
            depth_dim:                      `int` of dimension of dense layer.
            hidden_dim:                     `int` of dimension of hidden (encoder) layer.
            self_attention_activation_list: `list` of `str` of activation functions for the self-attention model.
            multi_head_attention_activation_list:   `list` of `str` of activation functions for the multi-head attention model.
            fc_activation_list:             `list` of `str` of activation functions in fully-connected layers.
            learning_rate:                  `float` of learning rate.
            weight_decay:                   `float` of weight decay for the default `AdamW` optimizer.
            dropout_rate:                   `float` of dropout rate.
            ctx:                            Device such as `"cpu"` or `"cuda"`.
            regularizatable_data_list:      `list` of `RegularizatableData`.
            not_init_flag:                  `bool`. If `True`, this class does not initialize its optimizer.

        '''
        super(TransformingAutoEncoderController, self).__init__()

        if computable_loss is None:
            computable_loss = nn.CrossEntropyLoss()
        self.__computable_loss = computable_loss

        if encoder is None:
            if hidden_dim is None or hidden_dim == depth_dim:
                encoder = TransformerEncoder(
                    depth_dim=depth_dim,
                    layer_n=layer_n,
                    head_n=head_n,
                    self_attention_activation_list=
                    self_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    dropout_rate=dropout_rate,
                    ctx=ctx)
            else:
                encoder = TransformerEncoder(
                    depth_dim=hidden_dim,
                    layer_n=layer_n,
                    head_n=head_n,
                    self_attention_activation_list=
                    self_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    dropout_rate=dropout_rate,
                    ctx=ctx)
            encoder.embedding_flag = False
        else:
            if isinstance(encoder, TransformerModel) is False:
                raise TypeError(
                    "The type of `encoder` must be `TransformerModel`.")

        if decoder is None:
            if hidden_dim is None or hidden_dim == depth_dim:
                decoder = TransformerDecoder(
                    head_n=head_n,
                    depth_dim=depth_dim,
                    layer_n=layer_n,
                    self_attention_activation_list=
                    self_attention_activation_list,
                    multi_head_attention_activation_list=
                    multi_head_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    ctx=ctx)
            else:
                decoder = TransformerDecoder(
                    head_n=head_n,
                    depth_dim=hidden_dim,
                    output_dim=hidden_dim,
                    layer_n=layer_n,
                    self_attention_activation_list=
                    self_attention_activation_list,
                    multi_head_attention_activation_list=
                    multi_head_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    ctx=ctx)
            decoder.embedding_flag = False
        else:
            if isinstance(decoder, TransformerModel) is False:
                raise TypeError(
                    "The type of `decoder` must be `TransformerModel`.")

        if reconstructor is None:
            if hidden_dim is None or hidden_dim == depth_dim:
                reconstructor = TransformerReconstructor(
                    head_n=head_n,
                    depth_dim=depth_dim,
                    layer_n=layer_n,
                    self_attention_activation_list=
                    self_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    ctx=ctx,
                )
            else:
                reconstructor = TransformerReconstructor(
                    head_n=head_n,
                    depth_dim=hidden_dim,
                    output_dim=depth_dim,
                    layer_n=layer_n,
                    self_attention_activation_list=
                    self_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    ctx=ctx,
                )
            reconstructor.embedding_flag = False
        else:
            if isinstance(reconstructor, TransformerModel) is False:
                raise TypeError(
                    "The type of `reconstructor` must be `TransformerModel`.")

        logger = getLogger("accelbrainbase")
        self.logger = logger

        self.encoder = encoder
        self.decoder = decoder
        self.reconstructor = reconstructor

        if hidden_dim is not None and hidden_dim != depth_dim:
            self.encoder_hidden_fc = nn.Linear(
                depth_dim,
                hidden_dim,
                bias=True,
            )
            self.decoder_hidden_fc = nn.Linear(
                depth_dim,
                hidden_dim,
                bias=True,
            )
            init_flag = True
        else:
            self.encoder_hidden_fc = None
            self.decoder_hidden_fc = None
            init_flag = False

        self.__ctx = ctx
        self.to(self.__ctx)

        if self.init_deferred_flag is False:
            if not_init_flag is False:
                if optimizer_f is not None:
                    if init_flag is True:
                        self.optimizer = optimizer_f(self.parameters(), )
                else:
                    if init_flag is True:
                        self.optimizer = AdamW(self.parameters(),
                                               lr=learning_rate,
                                               weight_decay=weight_decay)

        for v in regularizatable_data_list:
            if isinstance(v, RegularizatableData) is False:
                raise TypeError(
                    "The type of values of `regularizatable_data_list` must be `RegularizatableData`."
                )
        self.__regularizatable_data_list = regularizatable_data_list

        self.__learning_rate = learning_rate
        self.__weight_decay = weight_decay
        self.seq_len = seq_len
        self.epoch = 0

    def learn(self, iteratable_data):
        '''
        Learn samples drawn by `IteratableData.generate_learned_samples()`.

        Args:
            iteratable_data:     is-a `TransformerIterator`.
        '''
        if isinstance(iteratable_data, TransformerIterator) is False:
            raise TypeError(
                "The type of `iteratable_data` must be `TransformerIterator`.")

        self.__loss_list = []
        learning_rate = self.__learning_rate

        try:
            epoch = self.epoch
            iter_n = 0
            for encoded_observed_arr, decoded_observed_arr, encoded_mask_arr, decoded_mask_arr, test_encoded_observed_arr, test_decoded_observed_arr, test_encoded_mask_arr, test_decoded_mask_arr, training_target_arr, test_target_arr in iteratable_data.generate_learned_samples(
            ):
                self.epoch = epoch
                if self.encoder.optimizer is not None and self.decoder.optimizer is not None:
                    optimizer_setup_flag = True
                    self.encoder.optimizer.zero_grad()
                    self.decoder.optimizer.zero_grad()
                    self.optimizer.zero_grad()
                else:
                    optimizer_setup_flag = False

                pred_arr = self.inference(encoded_observed_arr,
                                          decoded_observed_arr,
                                          encoded_mask_arr, decoded_mask_arr)
                loss = self.compute_loss(pred_arr, training_target_arr)
                if optimizer_setup_flag is False:
                    self.encoder.optimizer.zero_grad()
                    self.decoder.optimizer.zero_grad()
                    self.optimizer.zero_grad()
                    pred_arr = self.inference(encoded_observed_arr,
                                              decoded_observed_arr,
                                              encoded_mask_arr,
                                              decoded_mask_arr)
                    loss = self.compute_loss(pred_arr, training_target_arr)

                loss.backward()
                self.optimizer.step()
                self.decoder.optimizer.step()
                self.encoder.optimizer.step()
                self.regularize()
                self.decoder.regularize()
                self.encoder.regularize()

                if (iter_n + 1) % int(
                        iteratable_data.iter_n / iteratable_data.epochs) == 0:
                    with torch.inference_mode():
                        test_pred_arr = self.inference(
                            test_encoded_observed_arr,
                            test_decoded_observed_arr, test_encoded_mask_arr,
                            test_decoded_mask_arr)

                        test_loss = self.compute_loss(test_pred_arr,
                                                      test_target_arr)

                    _loss = loss.to('cpu').detach().numpy().copy()
                    _test_loss = test_loss.to('cpu').detach().numpy().copy()

                    self.__loss_list.append((_loss, _test_loss))
                    self.logger.debug("Epochs: " + str(epoch + 1) +
                                      " Train loss: " + str(_loss) +
                                      " Test loss: " + str(_test_loss))
                    epoch += 1
                iter_n += 1

        except KeyboardInterrupt:
            self.logger.debug("Interrupt.")

        self.logger.debug("end. ")
        self.epoch = epoch

    def inference(
        self,
        encoded_observed_arr,
        decoded_observed_arr,
        encoded_mask_arr=None,
        decoded_mask_arr=None,
    ):
        '''
        Inference samples drawn by `IteratableData.generate_inferenced_samples()`.

        Args:
            encoded_observed_arr:   rank-3 Array like or sparse matrix as the observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)

            decoded_observed_arr:   rank-3 Array like or sparse matrix as the observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)

            encoded_mask_arr:       rank-3 Array like or sparse matrix as the observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)

            decoded_mask_arr:       rank-3 Array like or sparse matrix as the observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)

        Returns:
            `torch.Tensor` of inferred feature points.
        '''
        return self(
            encoded_observed_arr,
            decoded_observed_arr,
            encoded_mask_arr,
            decoded_mask_arr,
        )

    def compute_loss(self, pred_arr, labeled_arr):
        '''
        Compute loss.

        Args:
            pred_arr:       `torch.Tensor`.
            labeled_arr:    `torch.Tensor`.

        Returns:
            loss.
        '''
        return self.__computable_loss(pred_arr, labeled_arr)

    def forward(
        self,
        encoded_observed_arr,
        decoded_observed_arr,
        encoded_mask_arr=None,
        decoded_mask_arr=None,
    ):
        '''
        Forward.

        Args:
            encoded_observed_arr:   rank-3 Array like or sparse matrix as the observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)

            decoded_observed_arr:   rank-3 Array like or sparse matrix as the observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)

            encoded_mask_arr:       rank-3 Array like or sparse matrix as the observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)

            decoded_mask_arr:       rank-3 Array like or sparse matrix as the observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)
        
        Returns:
            `torch.Tensor` of inferred feature points.
        '''
        if self.__loaded_filename is not None:
            loaded_filename = self.__loaded_filename
            self.__loaded_filename = None
            init_encoded_observed_arr = encoded_observed_arr.detach()
            init_decoded_observed_arr = decoded_observed_arr.detach()
            if encoded_mask_arr is not None:
                init_encoded_mask_arr = encoded_mask_arr.detach()
            else:
                init_encoded_mask_arr = None
            if decoded_mask_arr is not None:
                init_decoded_mask_arr = decoded_mask_arr.detach()
            else:
                init_decoded_mask_arr = decoded_mask_arr

            _ = self.forward(
                init_encoded_observed_arr,
                init_decoded_observed_arr,
                init_encoded_mask_arr,
                init_decoded_mask_arr,
            )
            self.load_parameters(loaded_filename, ctx=self.__loaded_ctx)
            self.__loaded_ctx = None

        if encoded_mask_arr is None:
            encoded_mask_arr = torch.ones(
                (encoded_observed_arr.shape[0], 1, 1, 1), )
            encoded_mask_arr = encoded_mask_arr.to(encoded_observed_arr.device)
        if decoded_mask_arr is None:
            decoded_mask_arr = torch.ones(
                (decoded_observed_arr.shape[0], 1, 1, 1), )
            decoded_mask_arr = decoded_mask_arr.to(decoded_observed_arr.device)

        if self.encoder_hidden_fc is not None:
            encoded_observed_arr = self.encoder_hidden_fc(encoded_observed_arr)
        if self.decoder_hidden_fc is not None:
            decoded_observed_arr = self.decoder_hidden_fc(decoded_observed_arr)

        encoded_arr = self.encoder(encoded_observed_arr, encoded_mask_arr)
        decoded_arr = self.decoder(
            decoded_observed_arr,
            encoded_arr,
            decoded_mask_arr,
            encoded_mask_arr,
        )

        self.feature_points_arr = decoded_arr

        reconstructed_arr = self.reconstructor(decoded_arr, None)

        return reconstructed_arr

    def extract_learned_dict(self):
        '''
        Extract (pre-) learned parameters.

        Returns:
            `dict` of the parameters.
        '''
        params_dict = {}
        for k in self.state_dict().keys():
            params_dict.setdefault(k, self.state_dict()[k])

        return params_dict

    def regularize(self):
        '''
        Regularization.
        '''
        if len(self.__regularizatable_data_list) > 0:
            params_dict = self.extract_learned_dict()
            for regularizatable in self.__regularizatable_data_list:
                params_dict = regularizatable.regularize(params_dict)

            for k, params in params_dict.items():
                self.load_state_dict({k: params}, strict=False)

    def __rename_file(self, filename):
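        # Derives per-module filenames from `filename`; e.g. (illustrative name) "model.pt" becomes
        # "model_encoder.pt", "model_decoder.pt", and "model_reconstructor.pt".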
        filename_list = filename.split(".")
        _format = filename_list[-1]
        encoder_filename = filename.replace("." + _format,
                                            "_encoder." + _format)
        decoder_filename = filename.replace("." + _format,
                                            "_decoder." + _format)
        reconstructor_filename = filename.replace("." + _format,
                                                  "_reconstructor." + _format)
        return encoder_filename, decoder_filename, reconstructor_filename

    def save_parameters(self, filename):
        '''
        Save parameters to files.

        Args:
            filename:       File name.
        '''
        encoder_filename, decoder_filename, reconstructor_filename = self.__rename_file(
            filename)

        self.encoder.epoch = self.epoch
        self.encoder.loss_arr = self.loss_arr
        self.encoder.save_parameters(encoder_filename)

        self.decoder.epoch = self.epoch
        self.decoder.loss_arr = self.loss_arr
        self.decoder.save_parameters(decoder_filename)

        self.reconstructor.epoch = self.epoch
        self.reconstructor.loss_arr = self.loss_arr
        self.reconstructor.save_parameters(reconstructor_filename)

        torch.save(
            {
                'model_state_dict': self.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'epoch': self.epoch,
                'loss': self.loss_arr,
            }, filename)

    def load_parameters(self, filename, ctx=None, strict=True):
        '''
        Load parameters from files.

        Args:
            filename:       File name.
            ctx:            Device to move the loaded parameters to.
            strict:         Whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: `True`.
        '''
        try:
            encoder_filename, decoder_filename, reconstructor_filename = self.__rename_file(
                filename)
            self.encoder.load_parameters(encoder_filename,
                                         ctx=ctx,
                                         strict=strict)
            self.decoder.load_parameters(decoder_filename,
                                         ctx=ctx,
                                         strict=strict)
            self.reconstructor.load_parameters(reconstructor_filename,
                                               ctx=ctx,
                                               strict=strict)

            checkpoint = torch.load(filename)
            self.load_state_dict(checkpoint['model_state_dict'], strict=strict)
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.epoch = checkpoint['epoch']
            self.__loss_list = checkpoint['loss'].tolist()
        except RuntimeError:
            self.__loaded_filename = filename
            self.__loaded_ctx = ctx

        if ctx is not None:
            self.to(ctx)
            self.encoder.to(ctx)
            self.decoder.to(ctx)
            self.reconstructor.to(ctx)
            self.__ctx = ctx

    def set_readonly(self, value):
        ''' setter '''
        raise TypeError("This property must be read-only.")

    def get_init_deferred_flag(self):
        ''' getter for `bool` that means initialization in this class will be deferred or not.'''
        return self.__init_deferred_flag

    def set_init_deferred_flag(self, value):
        ''' setter for `bool` that means initialization in this class will be deferred or not.'''
        self.__init_deferred_flag = value

    init_deferred_flag = property(get_init_deferred_flag,
                                  set_init_deferred_flag)

    __loss_list = []

    def get_loss_arr(self):
        ''' getter '''
        return np.array(self.__loss_list)

    def set_loss_arr(self, value):
        ''' setter '''
        raise TypeError("This property must be read-only.")

    loss_arr = property(get_loss_arr, set_loss_arr)
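
# Usage sketch (illustrative only; the iterator and array names are assumptions, apart from the
# documented requirement that `learn()` receives a `TransformerIterator`):
#
#     controller = TransformingAutoEncoderController(
#         layer_n=3, head_n=3, seq_len=5, depth_dim=100, hidden_dim=100,
#         learning_rate=1e-05, dropout_rate=0.5, ctx="cuda")
#     controller.learn(transformer_iterator)
#     reconstructed_arr = controller.inference(encoded_observed_arr, decoded_observed_arr)
#     controller.save_parameters("transforming_auto_encoder.pt")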
Example #3
def train():
    global writer
    # For parsing commandline arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset_root",
        type=str,
        required=True,
        help='path to dataset folder containing train-test-validation folders')
    parser.add_argument("--checkpoint_dir",
                        type=str,
                        required=True,
                        help='path to folder for saving checkpoints')
    parser.add_argument("--checkpoint",
                        type=str,
                        help='path of checkpoint for pretrained model')
    parser.add_argument(
        "--train_continue",
        action='store_true',
        help=
        'If resuming from checkpoint, pass this flag and set `checkpoint` path.'
    )
    parser.add_argument("--epochs",
                        type=int,
                        default=200,
                        help='number of epochs to train. Default: 200.')
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=3,
                        help='batch size for training. Default: 3.')
    parser.add_argument("--validation_batch_size",
                        type=int,
                        default=6,
                        help='batch size for validation. Default: 6.')
    parser.add_argument("--init_learning_rate",
                        type=float,
                        default=0.0001,
                        help='set initial learning rate. Default: 0.0001.')
    parser.add_argument(
        "--milestones",
        type=int,
        nargs='+',
        default=[25, 50],
        help=
        'UNUSED NOW: epoch values at which to decrease learning rate by a factor of 0.1. Default: [25, 50].'
    )
    parser.add_argument(
        "--progress_iter",
        type=int,
        default=200,
        help=
        'frequency of reporting progress and validation. N: after every N iterations. Default: 200.'
    )
    parser.add_argument(
        "--checkpoint_epoch",
        type=int,
        default=5,
        help=
        'checkpoint saving frequency. N: after every N epochs. Each checkpoint is roughly 151 MB in size. Default: 5.'
    )
    args = parser.parse_args()

    ##[TensorboardX](https://github.com/lanpa/tensorboardX)
    ### For visualizing loss and interpolated frames
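    # `writer` is used by the logging calls further down (add_scalars / add_scalar / add_image); a minimal
    # setup, assuming the tensorboardX SummaryWriter linked above (the log directory name is an assumption):
    from tensorboardX import SummaryWriter
    writer = SummaryWriter('log')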

    ###Initialize flow computation and arbitrary-time flow interpolation CNNs.

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(device)
    flowComp = model.UNet(6, 4)
    flowComp.to(device)
    ArbTimeFlowIntrp = model.UNet(20, 5)
    ArbTimeFlowIntrp.to(device)

    ###Initialize backward warpers for train and validation datasets

    train_W_dim = 352
    train_H_dim = 352

    trainFlowBackWarp = model.backWarp(train_W_dim, train_H_dim, device)
    trainFlowBackWarp = trainFlowBackWarp.to(device)
    validationFlowBackWarp = model.backWarp(train_W_dim * 2, train_H_dim,
                                            device)
    validationFlowBackWarp = validationFlowBackWarp.to(device)

    ###Load Datasets

    # Channel wise mean calculated on custom training dataset
    # mean = [0.43702903766008444, 0.43715053433990597, 0.40436416782660994]
    mean = [0.5] * 3
    std = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean, std=std)
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    trainset = dataloader.SuperSloMo(root=args.dataset_root + '/train',
                                     randomCropSize=(train_W_dim, train_H_dim),
                                     transform=transform,
                                     train=True)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.train_batch_size,
                                              shuffle=True,
                                              num_workers=2,
                                              pin_memory=True)

    validationset = dataloader.SuperSloMo(
        root=args.dataset_root + '/validation',
        transform=transform,
        randomCropSize=(2 * train_W_dim, train_H_dim),
        train=False)
    validationloader = torch.utils.data.DataLoader(
        validationset,
        batch_size=args.validation_batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True)

    print(trainset, validationset)

    ###Create transform to display image from tensor

    negmean = [x * -1 for x in mean]
    revNormalize = transforms.Normalize(mean=negmean, std=std)
    TP = transforms.Compose([revNormalize, transforms.ToPILImage()])

    ###Utils

    def get_lr(optimizer):
        for param_group in optimizer.param_groups:
            return param_group['lr']

    ###Loss and Optimizer

    L1_lossFn = nn.L1Loss()
    MSE_LossFn = nn.MSELoss()

    if args.train_continue:
        dict1 = torch.load(args.checkpoint)
        last_epoch = dict1['epoch'] * len(trainloader)
    else:
        last_epoch = -1

    params = list(ArbTimeFlowIntrp.parameters()) + list(flowComp.parameters())

    optimizer = AdamW(params, lr=args.init_learning_rate, amsgrad=True)
    # optimizer = optim.SGD(params, lr=args.init_learning_rate, momentum=0.9, nesterov=True)

    # Scheduler: ReduceLROnPlateau below decreases the learning rate by a factor of 10 when the loss plateaus.
    # Suggested patience value:
    #   patience = (number of items in train dataset / train_batch_size) * (number of epochs of patience)
    # The scheduler counts calls to scheduler.step() (made per training iteration here), not epochs,
    # so one epoch of patience corresponds to roughly len(trainloader) steps (a rough dataset count is fine).
    # If the model seems to plateau quickly, reduce the number of patience epochs accordingly.
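    # Illustrative numbers (not from this dataset): with ~9,000 training samples and train_batch_size=3,
    # len(trainloader) is about 3,000, so 3 epochs of patience gives patience ≈ 3 * 3,000 = 9,000 steps,
    # matching the `patience=len(trainloader) * 3` setting below.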

    # scheduler = optim.lr_scheduler.CyclicLR(optimizer,
    #                                         base_lr=1e-8,
    #                                         max_lr=9.0e-3,
    #                                         step_size_up=3500,
    #                                         mode='triangular2',
    #                                         cycle_momentum=False,
    #                                         last_epoch=last_epoch)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=len(trainloader) * 3,
        cooldown=len(trainloader) * 2,
        verbose=True,
        min_lr=1e-8)

    # ReduceLROnPlateau was chosen to make the schedule more adaptive: this model tends to converge or
    # plateau quickly, with fairly rapid swings over time, so reacting to stagnation directly seems more
    # useful than decaying the learning rate at fixed milestones.

    ###Initializing VGG16 model for perceptual loss

    vgg16 = torchvision.models.vgg16(pretrained=True)
    vgg16_conv_4_3 = nn.Sequential(*list(vgg16.children())[0][:22])
    vgg16_conv_4_3.to(device)

    for param in vgg16_conv_4_3.parameters():
        param.requires_grad = False

    # Validation function

    def validate():
        # For details see training.
        psnr = 0
        tloss = 0
        flag = 1
        with torch.no_grad():
            for validationIndex, (validationData,
                                  validationFrameIndex) in enumerate(
                                      validationloader, 0):
                frame0, frameT, frame1 = validationData

                I0 = frame0.to(device)
                I1 = frame1.to(device)
                IFrame = frameT.to(device)

                torch.cuda.empty_cache()
                flowOut = flowComp(torch.cat((I0, I1), dim=1))
                F_0_1 = flowOut[:, :2, :, :]
                F_1_0 = flowOut[:, 2:, :, :]

                fCoeff = model.getFlowCoeff(validationFrameIndex, device)
                torch.cuda.empty_cache()
                F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                g_I0_F_t_0 = validationFlowBackWarp(I0, F_t_0)
                g_I1_F_t_1 = validationFlowBackWarp(I1, F_t_1)
                torch.cuda.empty_cache()
                intrpOut = ArbTimeFlowIntrp(
                    torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1,
                               g_I0_F_t_0),
                              dim=1))

                F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
                V_t_1 = 1 - V_t_0
                # torch.cuda.empty_cache()
                g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f)
                g_I1_F_t_1_f = validationFlowBackWarp(I1, F_t_1_f)

                wCoeff = model.getWarpCoeff(validationFrameIndex, device)
                torch.cuda.empty_cache()
                Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 *
                        g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                # For tensorboard
                if (flag):
                    retImg = torchvision.utils.make_grid([
                        revNormalize(frame0[0]),
                        revNormalize(frameT[0]),
                        revNormalize(Ft_p.cpu()[0]),
                        revNormalize(frame1[0])
                    ],
                                                         padding=10)
                    flag = 0

                # loss
                recnLoss = L1_lossFn(Ft_p, IFrame)
                # torch.cuda.empty_cache()
                prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p),
                                      vgg16_conv_4_3(IFrame))

                warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(
                    g_I1_F_t_1, IFrame) + L1_lossFn(
                        validationFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(
                            validationFlowBackWarp(I1, F_0_1), I0)
                torch.cuda.empty_cache()
                loss_smooth_1_0 = torch.mean(
                    torch.abs(F_1_0[:, :, :, :-1] -
                              F_1_0[:, :, :, 1:])) + torch.mean(
                                  torch.abs(F_1_0[:, :, :-1, :] -
                                            F_1_0[:, :, 1:, :]))
                loss_smooth_0_1 = torch.mean(
                    torch.abs(F_0_1[:, :, :, :-1] -
                              F_0_1[:, :, :, 1:])) + torch.mean(
                                  torch.abs(F_0_1[:, :, :-1, :] -
                                            F_0_1[:, :, 1:, :]))
                loss_smooth = loss_smooth_1_0 + loss_smooth_0_1

                # torch.cuda.empty_cache()
                loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth
                tloss += loss.item()

                # psnr
                MSE_val = MSE_LossFn(Ft_p, IFrame)
                psnr += (10 * log10(1 / MSE_val.item()))
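                # PSNR = 10 * log10(MAX^2 / MSE), with the peak value MAX taken as 1 here, which
                # reduces to the 10 * log10(1 / MSE) used above.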
                torch.cuda.empty_cache()

        return (psnr / len(validationloader)), (tloss /
                                                len(validationloader)), retImg

    ### Initialization

    if args.train_continue:
        ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])
        flowComp.load_state_dict(dict1['state_dictFC'])

        optimizer.load_state_dict(dict1.get('state_optimizer', {}))
        scheduler.load_state_dict(dict1.get('state_scheduler', {}))

        for param_group in optimizer.param_groups:
            param_group['lr'] = dict1.get('learningRate',
                                          args.init_learning_rate)

    else:
        dict1 = {'loss': [], 'valLoss': [], 'valPSNR': [], 'epoch': -1}

    ### Training

    import time

    start = time.time()
    cLoss = dict1['loss']
    valLoss = dict1['valLoss']
    valPSNR = dict1['valPSNR']
    checkpoint_counter = 0

    ### Main training loop

    optimizer.step()

    for epoch in range(dict1['epoch'] + 1, args.epochs):
        print("Epoch: ", epoch)

        # Append and reset
        cLoss.append([])
        valLoss.append([])
        valPSNR.append([])
        iLoss = 0

        for trainIndex, (trainData,
                         trainFrameIndex) in enumerate(trainloader, 0):

            ## Getting the input and the target from the training set
            frame0, frameT, frame1 = trainData

            I0 = frame0.to(device)
            I1 = frame1.to(device)
            IFrame = frameT.to(device)
            optimizer.zero_grad()
            # torch.cuda.empty_cache()
            # Calculate flow between reference frames I0 and I1
            flowOut = flowComp(torch.cat((I0, I1), dim=1))

            # Extracting flows between I0 and I1 - F_0_1 and F_1_0
            F_0_1 = flowOut[:, :2, :, :]
            F_1_0 = flowOut[:, 2:, :, :]

            fCoeff = model.getFlowCoeff(trainFrameIndex, device)

            # Calculate intermediate flows
            F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
            F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

            # Get intermediate frames from the intermediate flows
            g_I0_F_t_0 = trainFlowBackWarp(I0, F_t_0)
            g_I1_F_t_1 = trainFlowBackWarp(I1, F_t_1)
            torch.cuda.empty_cache()
            # Calculate optical flow residuals and visibility maps
            intrpOut = ArbTimeFlowIntrp(
                torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1,
                           g_I0_F_t_0),
                          dim=1))

            # Extract optical flow residuals and visibility maps
            F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
            F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
            V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
            V_t_1 = 1 - V_t_0
            # torch.cuda.empty_cache()
            # Get intermediate frames from the intermediate flows
            g_I0_F_t_0_f = trainFlowBackWarp(I0, F_t_0_f)
            g_I1_F_t_1_f = trainFlowBackWarp(I1, F_t_1_f)
            # torch.cuda.empty_cache()
            wCoeff = model.getWarpCoeff(trainFrameIndex, device)
            torch.cuda.empty_cache()
            # Calculate final intermediate frame
            Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 *
                    g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

            # Loss
            recnLoss = L1_lossFn(Ft_p, IFrame)
            # torch.cuda.empty_cache()

            prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame))
            # torch.cuda.empty_cache()
            warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(
                g_I1_F_t_1, IFrame) + L1_lossFn(
                    trainFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(
                        trainFlowBackWarp(I1, F_0_1), I0)

            loss_smooth_1_0 = torch.mean(
                torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:])
            ) + torch.mean(torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :]))
            loss_smooth_0_1 = torch.mean(
                torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:])
            ) + torch.mean(torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :]))
            loss_smooth = loss_smooth_1_0 + loss_smooth_0_1
            # torch.cuda.empty_cache()
            # Total Loss - Coefficients 204 and 102 are used instead of 0.8 and 0.4
            # since the loss in paper is calculated for input pixels in range 0-255
            # and the input to our network is in range 0-1
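            # (0.8 * 255 = 204 and 0.4 * 255 = 102, so the 0-255-range weights from the paper carry over
            # to the 0-1-range inputs used here.)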
            loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth

            # Backpropagate

            loss.backward()
            optimizer.step()
            scheduler.step(loss.item())

            iLoss += loss.item()
            torch.cuda.empty_cache()
            # Validation and progress every `args.progress_iter` iterations
            if ((trainIndex % args.progress_iter) == args.progress_iter - 1):
                # Increment scheduler count
                scheduler.step(iLoss / args.progress_iter)

                end = time.time()

                psnr, vLoss, valImg = validate()
                optimizer.zero_grad()
                # torch.cuda.empty_cache()
                valPSNR[epoch].append(psnr)
                valLoss[epoch].append(vLoss)

                # Tensorboard
                itr = trainIndex + epoch * (len(trainloader))

                writer.add_scalars(
                    'Loss', {
                        'trainLoss': iLoss / args.progress_iter,
                        'validationLoss': vLoss
                    }, itr)
                writer.add_scalar('PSNR', psnr, itr)

                writer.add_image('Validation', valImg, itr)
                #####

                endVal = time.time()

                print(
                    " Loss: %0.6f  Iterations: %4d/%4d  TrainExecTime: %0.1f  ValLoss:%0.6f  ValPSNR: %0.4f  ValEvalTime: %0.2f LearningRate: %.1e"
                    % (iLoss / args.progress_iter, trainIndex,
                       len(trainloader), end - start, vLoss, psnr,
                       endVal - end, get_lr(optimizer)))

                # torch.cuda.empty_cache()
                cLoss[epoch].append(iLoss / args.progress_iter)
                iLoss = 0
                start = time.time()

        # Create checkpoint after every `args.checkpoint_epoch` epochs
        if (epoch % args.checkpoint_epoch) == args.checkpoint_epoch - 1:
            dict1 = {
                'Detail': "End to end Super SloMo.",
                'epoch': epoch,
                'timestamp': datetime.datetime.now(),
                'trainBatchSz': args.train_batch_size,
                'validationBatchSz': args.validation_batch_size,
                'learningRate': get_lr(optimizer),
                'loss': cLoss,
                'valLoss': valLoss,
                'valPSNR': valPSNR,
                'state_dictFC': flowComp.state_dict(),
                'state_dictAT': ArbTimeFlowIntrp.state_dict(),
                'state_optimizer': optimizer.state_dict(),
                'state_scheduler': scheduler.state_dict()
            }
            torch.save(
                dict1, args.checkpoint_dir + "/SuperSloMo" +
                str(checkpoint_counter) + ".ckpt")
            checkpoint_counter += 1
Example #4
class CienaTransformerModel(Model):
    def __init__(
            self, save_path, log_path, d_features, d_meta, max_length, d_classifier, n_classes, threshold=None,
            optimizer=None, **kwargs):
        '''**kwargs: n_layers, n_head, dropout, use_bottleneck, d_bottleneck'''
        '''
            Arguments:
                save_path -- model file path
                log_path -- log file path
                d_features -- how many PMs
                d_meta -- how many facility types
                max_length -- input sequence length
                d_classifier -- classifier hidden unit
                n_classes -- output dim
                threshold -- if not None, n_classes should be 1 (regression).
        '''

        super().__init__(save_path, log_path)
        self.d_output = n_classes
        self.threshold = threshold
        self.max_length = max_length

        # ----------------------------- Model ------------------------------ #

        self.model = Encoder(TimeFacilityEncoding, d_features=d_features, max_seq_length=max_length,
                             d_meta=d_meta, **kwargs)

        # --------------------------- Classifier --------------------------- #
        self.classifier = LinearClassifier(d_features * max_length, d_classifier, n_classes)

        # ------------------------------ CUDA ------------------------------ #
        # If GPU available, move the graph to GPU(s)
        self.CUDA_AVAILABLE = self.check_cuda()
        if self.CUDA_AVAILABLE:
            device_ids = list(range(torch.cuda.device_count()))
            self.model = nn.DataParallel(self.model, device_ids)
            self.classifier = nn.DataParallel(self.classifier, device_ids)
            self.model.to('cuda')
            self.classifier.to('cuda')
            assert (next(self.model.parameters()).is_cuda)
            assert (next(self.classifier.parameters()).is_cuda)
            pass

        else:
            print('CUDA not found or not enabled, using CPU instead')

        # ---------------------------- Optimizer --------------------------- #
        self.parameters = list(self.model.parameters()) + list(self.classifier.parameters())
        if optimizer is None:
            self.optimizer = AdamW(self.parameters, lr=0.001, betas=(0.9, 0.999), weight_decay=0.001)

        # ------------------------ training control ------------------------ #
        self.controller = TrainingControl(max_step=100000, evaluate_every_nstep=100, print_every_nstep=10)
        self.early_stopping = EarlyStopping(patience=50)

        # --------------------- logging and tensorboard -------------------- #
        self.count_parameters()
        self.set_logger()
        self.set_summary_writer()
        # ---------------------------- END INIT ---------------------------- #

    def checkpoint(self, step):
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'classifier_state_dict': self.classifier.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'global_step': step}
        return checkpoint

    def train_epoch(self, train_dataloader, eval_dataloader, device, smoothing, earlystop):
        ''' Epoch operation in training phase '''

        if device == 'cuda':
            assert self.CUDA_AVAILABLE
        # Set model and classifier to training mode
        self.model.train()
        self.classifier.train()

        total_loss = 0
        batch_counter = 0

        # update parameters once per batch
        for batch in tqdm(
                train_dataloader, mininterval=1,
                desc='  - (Training)   ', leave=False):  # train_dataloader should be an iterable

            # get data from dataloader

            input_feature_sequence, padding, time_facility, y = map(lambda x: x.to(device), batch)

            batch_size = len(y)

            non_pad_mask, slf_attn_mask = get_attn_mask(padding)

            # forward
            self.optimizer.zero_grad()
            logits, attn = self.model(input_feature_sequence, time_facility, non_pad_mask, slf_attn_mask)
            logits = logits.view(batch_size, -1)
            logits = self.classifier(logits)

            # A single output dimension indicates a regression problem
            if self.d_output == 1:
                pred = logits.sigmoid()
                loss = mse_loss(pred, y)

            else:
                pred = logits
                loss = cross_entropy_loss(pred, y, smoothing=smoothing)

            # calculate gradients
            loss.backward()

            # update parameters
            self.optimizer.step()

            # get metrics for logging
            acc = accuracy(pred, y, threshold=self.threshold)
            precision, recall, precision_avg, recall_avg = precision_recall(pred, y, self.d_output,
                                                                            threshold=self.threshold)
            total_loss += loss.item()
            batch_counter += 1

            # training control
            state_dict = self.controller(batch_counter)

            if state_dict['step_to_print']:
                self.train_logger.info(
                    '[TRAINING]   - step: %5d, loss: %3.4f, acc: %1.4f, pre: %1.4f, rec: %1.4f' % (
                        state_dict['step'], loss, acc, precision[1], recall[1]))
                self.summary_writer.add_scalar('loss/train', loss, state_dict['step'])
                self.summary_writer.add_scalar('acc/train', acc, state_dict['step'])
                self.summary_writer.add_scalar('precision/train', precision[1], state_dict['step'])
                self.summary_writer.add_scalar('recall/train', recall[1], state_dict['step'])

            if state_dict['step_to_evaluate']:
                stop = self.val_epoch(eval_dataloader, device, state_dict['step'])
                state_dict['step_to_stop'] = stop

                if earlystop and stop:
                    break

            if self.controller.current_step == self.controller.max_step:
                state_dict['step_to_stop'] = True
                break

        return state_dict

    def val_epoch(self, dataloader, device, step=0, plot=False):
        ''' Epoch operation in evaluation phase '''
        if device == 'cuda':
            assert self.CUDA_AVAILABLE

        # Set model and classifier to evaluation mode
        self.model.eval()
        self.classifier.eval()

        # use evaluator to calculate the average performance
        evaluator = Evaluator()

        pred_list = []
        real_list = []

        with torch.no_grad():

            for batch in tqdm(
                    dataloader, mininterval=5,
                    desc='  - (Evaluation)   ', leave=False):  # dataloader should be an iterable

                input_feature_sequence, padding, time_facility, y = map(lambda x: x.to(device), batch)

                batch_size = len(y)

                non_pad_mask, slf_attn_mask = get_attn_mask(padding)

                # get logits
                logits, attn = self.model(input_feature_sequence, time_facility, non_pad_mask, slf_attn_mask)
                logits = logits.view(batch_size, -1)
                logits = self.classifier(logits)

                if self.d_output == 1:
                    pred = logits.sigmoid()
                    loss = mse_loss(pred, y)

                else:
                    pred = logits
                    loss = cross_entropy_loss(pred, y, smoothing=False)

                acc = accuracy(pred, y, threshold=self.threshold)
                precision, recall, _, _ = precision_recall(pred, y, self.d_output, threshold=self.threshold)

                # feed the metrics in the evaluator
                evaluator(loss.item(), acc.item(), precision[1].item(), recall[1].item())

                # append the results to the predict / real lists for drawing an ROC or PR curve
            #     if plot:
            #         pred_list += pred.tolist()
            #         real_list += y.tolist()
            #
            # if plot:
            #     area, precisions, recalls, thresholds = pr(pred_list, real_list)
            #     plot_pr_curve(recalls, precisions, auc=area)

            # get evaluation results from the evaluator
            loss_avg, acc_avg, pre_avg, rec_avg = evaluator.avg_results()

            self.eval_logger.info(
                '[EVALUATION] - step: %5d, loss: %3.4f, acc: %1.4f, pre: %1.4f, rec: %1.4f' % (
                    step, loss_avg, acc_avg, pre_avg, rec_avg))
            self.summary_writer.add_scalar('loss/eval', loss_avg, step)
            self.summary_writer.add_scalar('acc/eval', acc_avg, step)
            self.summary_writer.add_scalar('precision/eval', pre_avg, step)
            self.summary_writer.add_scalar('recall/eval', rec_avg, step)

            state_dict = self.early_stopping(loss_avg)

            if state_dict['save']:
                checkpoint = self.checkpoint(step)
                self.save_model(checkpoint, self.save_path + '-step-%d_loss-%.5f' % (step, loss_avg))

            return state_dict['break']

    def train(self, max_epoch, train_dataloader, eval_dataloader, device,
              smoothing=False, earlystop=False, save_mode='best'):
        assert save_mode in ['all', 'best']

        # train for max_epoch epochs
        for epoch_i in range(max_epoch):
            print('[ Epoch', epoch_i, ']')
            # set current epoch
            self.controller.set_epoch(epoch_i + 1)
            # train for one epoch
            state_dict = self.train_epoch(train_dataloader, eval_dataloader, device, smoothing, earlystop)

            # if state_dict['step_to_stop']:
            #     break

        checkpoint = self.checkpoint(state_dict['step'])

        self.save_model(checkpoint, self.save_path + '-step-%d' % state_dict['step'])

        self.train_logger.info(
            '[INFO]: Finished training after %d epoch(s) and %d batches, %d training steps in total.' % (
                state_dict['epoch'] - 1, state_dict['batch'], state_dict['step']))

    def get_predictions(self, data_loader, device, max_batches=None, activation=None):

        pred_list = []
        real_list = []

        self.model.eval()
        self.classifier.eval()

        batch_counter = 0

        with torch.no_grad():
            for batch in tqdm(
                    data_loader,
                    desc='  - (Testing)   ', leave=False):

                input_feature_sequence, padding, time_facility, y = map(lambda x: x.to(device), batch)

                non_pad_mask, slf_attn_mask = get_attn_mask(padding)

                # get logits
                logits, attn = self.model(input_feature_sequence, time_facility, non_pad_mask, slf_attn_mask)
                logits = logits.view(logits.shape[0], -1)
                logits = self.classifier(logits)

                # Apply the given activation function if provided; otherwise default to softmax
                if activation is not None:
                    pred = activation(logits)
                else:
                    pred = logits.softmax(dim=-1)
                pred_list += pred.tolist()
                real_list += y.tolist()

                if max_batches is not None:
                    batch_counter += 1
                    if batch_counter >= max_batches:
                        break

        return pred_list, real_list
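
For reference, a hypothetical usage sketch of CienaTransformerModel follows. The paths, dimensions, and dataloaders below are illustrative placeholders, not values from the original code; each batch is assumed to yield the (features, padding, time_facility, label) tensors consumed by train_epoch() above.

# Hypothetical usage sketch -- paths, sizes and dataloaders are placeholders.
model = CienaTransformerModel(
    save_path='models/ciena', log_path='logs/ciena',
    d_features=64, d_meta=8, max_length=32,
    d_classifier=256, n_classes=2,
    n_layers=4, n_head=8, dropout=0.1)

device = 'cuda' if model.CUDA_AVAILABLE else 'cpu'
model.train(max_epoch=10,
            train_dataloader=train_loader,   # placeholder DataLoader
            eval_dataloader=val_loader,      # placeholder DataLoader
            device=device,
            smoothing=True, earlystop=True)

# Collect predictions on a held-out set
preds, labels = model.get_predictions(test_loader, device=device)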