Example #1
###############################################################################
# Load data
###############################################################################
TimeseriesData = preprocess_data.PickleDataLoad(data_type=args.data,
                                                filename=args.filename,
                                                augment_test_data=False)
train_dataset = TimeseriesData.batchify(
    args, TimeseriesData.trainData[:TimeseriesData.length], bsz=1)
test_dataset = TimeseriesData.batchify(args, TimeseriesData.testData, bsz=1)

###############################################################################
# Build the model
###############################################################################
nfeatures = TimeseriesData.trainData.size(-1)
model = model.RNNPredictor(rnn_type=args.model,
                           enc_inp_size=nfeatures,
                           rnn_inp_size=args.emsize,
                           rnn_hid_size=args.nhid,
                           dec_out_size=nfeatures,
                           nlayers=args.nlayers,
                           res_connection=args.res_connection).to(args.device)
model.load_state_dict(checkpoint['state_dict'])
#del checkpoint

scores, predicted_scores, precisions, recalls, f_betas = [], [], [], [], []
targets, mean_predictions, oneStep_predictions, Nstep_predictions = [], [], [], []
try:
    # For each channel in the dataset
    for channel_idx in range(nfeatures):
        ''' 1. Load the mean and covariance if they are pre-calculated; otherwise calculate them. '''
        # Mean and covariance are calculated on train dataset.
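The fragment above is cut off right where the per-channel mean and covariance would be loaded or computed via fit_norm_distribution_param (Example #2 below shows that step in full), and the helper's body is not shown anywhere in these snippets. A minimal self-contained sketch of the idea behind it — fitting a multivariate Gaussian to sliding windows of one-step prediction errors on the train data; the function and variable names are illustrative, not the repository's exact implementation:

import torch


def fit_error_distribution(errors, window):
    # errors: 1-D tensor of one-step prediction errors for a single channel.
    # Stack sliding windows of `window` consecutive errors and fit a
    # multivariate Gaussian (mean vector + covariance matrix) to them.
    windows = errors.unfold(0, window, 1)                     # [num_windows, window]
    mean = windows.mean(dim=0)                                # [window]
    centered = windows - mean
    cov = centered.t().mm(centered) / (windows.size(0) - 1)   # [window, window]
    return mean, cov


# Illustrative usage with random values standing in for real model residuals.
mean, cov = fit_error_distribution(torch.randn(1000), window=10)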
Example #2
def main():
    """Run inference."""

    from model import model

    parser = argparse.ArgumentParser(
        description='PyTorch RNN Anomaly Detection Model')
    parser.add_argument('--prediction_window_size',
                        type=int,
                        default=10,
                        help='prediction_window_size')
    parser.add_argument(
        '--data',
        type=str,
        default='ecg',
        help='type of the dataset '
        '(ecg, gesture, power_demand, space_shuttle, respiration, nyc_taxi)')
    parser.add_argument('--filename',
                        type=str,
                        default='chfdb_chf13_45590.pkl',
                        help='filename of the dataset')
    parser.add_argument('--save_fig',
                        action='store_true',
                        help='save results as figures')
    parser.add_argument(
        '--compensate',
        action='store_true',
        help='compensate anomaly score using anomaly score estimation')
    parser.add_argument('--beta',
                        type=float,
                        default=1.0,
                        help='beta value for f-beta score')
    parser.add_argument('--device',
                        type=str,
                        default='cuda',
                        help='cuda or cpu')

    args_ = parser.parse_args()
    print('-' * 89)
    print("=> loading checkpoint ")
    if args_.device == 'cpu':
        checkpoint = torch.load(str(
            Path('save', args_.data, 'checkpoint',
                 args_.filename).with_suffix('.pth')),
                                map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(
            str(
                Path('save', args_.data, 'checkpoint',
                     args_.filename).with_suffix('.pth')))
    args = checkpoint['args']
    args.prediction_window_size = args_.prediction_window_size
    args.beta = args_.beta
    args.save_fig = args_.save_fig
    args.compensate = args_.compensate
    args.device = args_.device
    print("=> loaded checkpoint")

    # Set the random seed manually for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    ###############################################################################
    # Load data
    ###############################################################################
    TimeseriesData = preprocess_data.PickleDataLoad(data_type=args.data,
                                                    filename=args.filename,
                                                    augment_test_data=False)
    train_dataset = TimeseriesData.batchify(
        args, TimeseriesData.trainData[:TimeseriesData.length], bsz=1)
    test_dataset = TimeseriesData.batchify(args,
                                           TimeseriesData.testData,
                                           bsz=1)

    ###############################################################################
    # Build the model
    ###############################################################################
    nfeatures = TimeseriesData.trainData.size(-1)
    model = model.RNNPredictor(rnn_type=args.model,
                               enc_inp_size=nfeatures,
                               rnn_inp_size=args.emsize,
                               rnn_hid_size=args.nhid,
                               dec_out_size=nfeatures,
                               nlayers=args.nlayers,
                               res_connection=args.res_connection).to(
                                   args.device)
    model.load_state_dict(checkpoint['state_dict'])

    scores, predicted_scores, precisions, recalls, f_betas = [], [], [], [], []
    targets, mean_predictions, oneStep_predictions, Nstep_predictions = [], [], [], []
    try:
        # For each channel in the dataset
        for channel_idx in range(nfeatures):
            ''' 1. Load the mean and covariance if they are pre-calculated; otherwise calculate them. '''
            # Mean and covariance are calculated on train dataset.
            if 'means' in checkpoint.keys() and 'covs' in checkpoint.keys():
                print('=> loading pre-calculated mean and covariance')
                mean, cov = checkpoint['means'][channel_idx], checkpoint[
                    'covs'][channel_idx]
            else:
                print('=> calculating mean and covariance')
                mean, cov = fit_norm_distribution_param(
                    args, model, train_dataset, channel_idx=channel_idx)
            ''' 2. Train an anomaly score predictor using support vector regression (SVR). (Optional) '''
            # An anomaly score predictor is trained on the hidden-layer outputs
            # and the corresponding anomaly scores from the train dataset.
            # Its predictions on the test dataset can be used as the baseline for an adaptive threshold.
            if args.compensate:
                print('=> training an SVR as anomaly score predictor')
                train_score, _, _, hiddens, _ = anomalyScore(
                    args,
                    model,
                    train_dataset,
                    mean,
                    cov,
                    channel_idx=channel_idx)
                score_predictor = GridSearchCV(SVR(),
                                               cv=5,
                                               param_grid={
                                                   "C": [1e0, 1e1, 1e2],
                                                   "gamma":
                                                   np.logspace(-1, 1, 3)
                                               })
                score_predictor.fit(
                    torch.cat(hiddens, dim=0).numpy(),
                    train_score.cpu().numpy())
            else:
                score_predictor = None
            ''' 3. Calculate anomaly scores '''
            # Anomaly scores are calculated on the test dataset
            # using the mean and covariance estimated on the train dataset.
            print('=> calculating anomaly scores')
            score, sorted_prediction, sorted_error, _, predicted_score = anomalyScore(
                args,
                model,
                test_dataset,
                mean,
                cov,
                score_predictor=score_predictor,
                channel_idx=channel_idx)
            ''' 4. Evaluate the result '''
            # The obtained anomaly scores are evaluated by measuring precision, recall, and F-beta scores.
            # The precision, recall, and F-beta scores are calculated repeatedly,
            # sampling the threshold from 1 to the maximum anomaly score, either equidistantly or logarithmically.
            print('=> calculating precision, recall, and f_beta')
            precision, recall, f_beta = get_precision_recall(
                args,
                score,
                num_samples=1000,
                beta=args.beta,
                label=TimeseriesData.testLabel.to(args.device))
            print('data: ', args.data, ' filename: ', args.filename,
                  ' f-beta (no compensation): ',
                  f_beta.max().item(), ' beta: ', args.beta)
            if args.compensate:
                precision, recall, f_beta = get_precision_recall(
                    args,
                    score,
                    num_samples=1000,
                    beta=args.beta,
                    label=TimeseriesData.testLabel.to(args.device),
                    predicted_score=predicted_score)
                print('data: ', args.data, ' filename: ', args.filename,
                      ' f-beta    (compensation): ',
                      f_beta.max().item(), ' beta: ', args.beta)

            target = preprocess_data.reconstruct(
                test_dataset.cpu()[:, 0, channel_idx],
                TimeseriesData.mean[channel_idx],
                TimeseriesData.std[channel_idx]).numpy()
            mean_prediction = preprocess_data.reconstruct(
                sorted_prediction.mean(dim=1).cpu(),
                TimeseriesData.mean[channel_idx],
                TimeseriesData.std[channel_idx]).numpy()
            oneStep_prediction = preprocess_data.reconstruct(
                sorted_prediction[:, -1].cpu(),
                TimeseriesData.mean[channel_idx],
                TimeseriesData.std[channel_idx]).numpy()
            Nstep_prediction = preprocess_data.reconstruct(
                sorted_prediction[:, 0].cpu(),
                TimeseriesData.mean[channel_idx],
                TimeseriesData.std[channel_idx]).numpy()
            sorted_errors_mean = sorted_error.abs().mean(dim=1).cpu()
            sorted_errors_mean *= TimeseriesData.std[channel_idx]
            sorted_errors_mean = sorted_errors_mean.numpy()
            score = score.cpu()
            scores.append(score)
            targets.append(target)
            predicted_scores.append(predicted_score)
            mean_predictions.append(mean_prediction)
            oneStep_predictions.append(oneStep_prediction)
            Nstep_predictions.append(Nstep_prediction)
            precisions.append(precision)
            recalls.append(recall)
            f_betas.append(f_beta)

            if args.save_fig:
                save_dir = Path(
                    'result', args.data,
                    args.filename).with_suffix('').joinpath('fig_detection')
                save_dir.mkdir(parents=True, exist_ok=True)
                plt.plot(precision.cpu().numpy(), label='precision')
                plt.plot(recall.cpu().numpy(), label='recall')
                plt.plot(f_beta.cpu().numpy(), label='f_beta')
                plt.legend()
                plt.xlabel('Threshold (log scale)')
                plt.ylabel('Value')
                plt.title('Anomaly Detection on ' + args.data + ' Dataset',
                          fontsize=18,
                          fontweight='bold')
                plt.savefig(
                    str(
                        save_dir.joinpath('fig_f_beta_channel' +
                                          str(channel_idx)).with_suffix(
                                              '.png')))
                plt.close()

                fig, ax1 = plt.subplots(figsize=(15, 5))
                ax1.plot(target,
                         label='Target',
                         color='black',
                         marker='.',
                         linestyle='--',
                         markersize=1,
                         linewidth=0.5)
                ax1.plot(mean_prediction,
                         label='Mean predictions',
                         color='purple',
                         marker='.',
                         linestyle='--',
                         markersize=1,
                         linewidth=0.5)
                ax1.plot(oneStep_prediction,
                         label='1-step predictions',
                         color='green',
                         marker='.',
                         linestyle='--',
                         markersize=1,
                         linewidth=0.5)
                ax1.plot(Nstep_prediction,
                         label=str(args.prediction_window_size) +
                         '-step predictions',
                         color='blue',
                         marker='.',
                         linestyle='--',
                         markersize=1,
                         linewidth=0.5)
                ax1.plot(sorted_errors_mean,
                         label='Absolute mean prediction errors',
                         color='orange',
                         marker='.',
                         linestyle='--',
                         markersize=1,
                         linewidth=1.0)
                ax1.legend(loc='upper left')
                ax1.set_ylabel('Value', fontsize=15)
                ax1.set_xlabel('Index', fontsize=15)
                ax2 = ax1.twinx()
                ax2.plot(
                    score.numpy().reshape(-1, 1),
                    label=
                    'Anomaly scores from \nmultivariate normal distribution',
                    color='red',
                    marker='.',
                    linestyle='--',
                    markersize=1,
                    linewidth=1)
                if args.compensate:
                    ax2.plot(predicted_score,
                             label='Predicted anomaly scores from SVR',
                             color='cyan',
                             marker='.',
                             linestyle='--',
                             markersize=1,
                             linewidth=1)
                ax2.legend(loc='upper right')
                ax2.set_ylabel('anomaly score', fontsize=15)
                plt.title('Anomaly Detection on ' + args.data + ' Dataset',
                          fontsize=18,
                          fontweight='bold')
                plt.tight_layout()
                plt.xlim([0, len(test_dataset)])
                plt.savefig(
                    str(
                        save_dir.joinpath('fig_scores_channel' +
                                          str(channel_idx)).with_suffix(
                                              '.png')))
                plt.close()

    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from anomaly detection early')

    print('=> saving the results as pickle files')
    save_dir = Path('result', args.data, args.filename).with_suffix('')
    save_dir.mkdir(parents=True, exist_ok=True)
    pickle.dump(targets, open(str(save_dir.joinpath('target.pkl')), 'wb'))
    pickle.dump(mean_predictions,
                open(str(save_dir.joinpath('mean_predictions.pkl')), 'wb'))
    pickle.dump(oneStep_predictions,
                open(str(save_dir.joinpath('oneStep_predictions.pkl')), 'wb'))
    pickle.dump(Nstep_predictions,
                open(str(save_dir.joinpath('Nstep_predictions.pkl')), 'wb'))
    pickle.dump(scores, open(str(save_dir.joinpath('score.pkl')), 'wb'))
    pickle.dump(predicted_scores,
                open(str(save_dir.joinpath('predicted_scores.pkl')), 'wb'))
    pickle.dump(precisions, open(str(save_dir.joinpath('precision.pkl')),
                                 'wb'))
    pickle.dump(recalls, open(str(save_dir.joinpath('recall.pkl')), 'wb'))
    pickle.dump(f_betas, open(str(save_dir.joinpath('f_beta.pkl')), 'wb'))
    print('-' * 89)
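anomalyScore, used throughout the example above, is imported from the repository and not defined in this snippet. Per the comments, the score comes from the multivariate normal distribution fitted on the train errors; a minimal sketch of such a score (a Mahalanobis-style distance per error window; the names and shapes below are assumptions, not the repository's exact code):

import torch


def gaussian_anomaly_score(error_windows, mean, cov):
    # error_windows: [num_windows, window] prediction-error windows from the test set.
    # mean, cov: Gaussian parameters fitted on the train set.
    centered = error_windows - mean                         # [num_windows, window]
    cov_inv = torch.inverse(cov)                            # [window, window]
    # (x - mu)^T * Sigma^{-1} * (x - mu), evaluated per window.
    return (centered.mm(cov_inv) * centered).sum(dim=1)     # [num_windows]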
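Similarly, get_precision_recall is not shown; per the comments it sweeps thresholds from 1 up to the maximum anomaly score and reports precision, recall, and F-beta at each threshold. A rough sketch of such a sweep, assuming 1-D float tensors for scores and binary labels (again an illustration, not the repository's implementation):

import torch


def precision_recall_fbeta_sweep(score, label, num_samples=1000, beta=1.0):
    # Sample thresholds equidistantly between 1 and the maximum score, then
    # compute precision, recall, and F-beta for the points flagged at each one.
    thresholds = torch.linspace(1.0, score.max().item(), num_samples)
    precisions, recalls, f_betas = [], [], []
    for th in thresholds:
        pred = (score > th).float()
        tp = (pred * label).sum()
        precision = tp / (pred.sum() + 1e-7)
        recall = tp / (label.sum() + 1e-7)
        f_beta = (1 + beta ** 2) * precision * recall / (beta ** 2 * precision + recall + 1e-7)
        precisions.append(precision)
        recalls.append(recall)
        f_betas.append(f_beta)
    return torch.stack(precisions), torch.stack(recalls), torch.stack(f_betas)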
Example #3
TimeseriesData = preprocess_data.PickleDataLoad(data_type=args.data,
                                                filename=args.filename,
                                                augment_test_data=args.augment)
train_dataset = TimeseriesData.batchify(args, TimeseriesData.trainData,
                                        args.batch_size)
test_dataset = TimeseriesData.batchify(args, TimeseriesData.testData,
                                       args.eval_batch_size)
gen_dataset = TimeseriesData.batchify(args, TimeseriesData.testData, 1)

###############################################################################
# Build the model
###############################################################################
feature_dim = TimeseriesData.trainData.size(1)
model = model.RNNPredictor(rnn_type=args.model,
                           enc_inp_size=feature_dim,
                           rnn_inp_size=args.emsize,
                           rnn_hid_size=args.nhid,
                           dec_out_size=feature_dim,
                           nlayers=args.nlayers,
                           dropout=args.dropout,
                           tie_weights=args.tied,
                           res_connection=args.res_connection).to(args.device)
optimizer = optim.Adam(model.parameters(),
                       lr=args.lr,
                       weight_decay=args.weight_decay)
criterion = nn.MSELoss()


###############################################################################
# Training code
###############################################################################
def get_batch(args, source, i):
    seq_len = min(args.bptt, len(source) - 1 - i)
    data = source[i:i + seq_len]  # [ seq_len * batch_size * feature_size ]
    target = source[i + 1:i + 1 + seq_len]  # [ seq_len * batch_size * feature_size ]
    return data, target


Example #4
train_dataset = batchify(TimeseriesData.trainData, 1)[:10000]
test_dataset = batchify(TimeseriesData.testData, 1)


###############################################################################
# Build the model
###############################################################################

model = model.RNNPredictor(rnn_type=args.model,
                           enc_inp_size=3,
                           rnn_inp_size=args.emsize,
                           rnn_hid_size=args.nhid,
                           dec_out_size=3,
                           nlayers=args.nlayers)

if args.cuda:
    model.cuda()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.MSELoss()
###############################################################################
# Training code
###############################################################################


def fit_norm_distribution_param(args, model, train_dataset, endPoint=10000):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    pasthidden = model.init_hidden(1)
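The fragment above calls a standalone batchify whose definition is not included in the snippet. A minimal sketch of what such a helper typically does in these scripts — trim the series to a whole number of batches, split it into contiguous chunks, and reshape to [nbatch, bsz, feature_size]; the exact signature and any device handling are assumptions:

import torch


def batchify(data, bsz):
    # data: [seq_len, feature_size] time series.
    # Drop the trailing steps that do not divide evenly into `bsz` columns,
    # split the rest into bsz contiguous chunks, and arrange the result as
    # [nbatch, bsz, feature_size].
    nbatch = data.size(0) // bsz
    trimmed = data.narrow(0, 0, nbatch * bsz)
    return trimmed.contiguous().view(bsz, nbatch, -1).transpose(0, 1).contiguous()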
Example #5
def main():
    """Run training"""
    from model import model

    parser = argparse.ArgumentParser(
        description='PyTorch RNN Prediction Model on Time-series Dataset')
    parser.add_argument(
        '--data',
        type=str,
        default='ecg',
        help='type of the dataset '
        '(ecg, gesture, power_demand, space_shuttle, respiration, nyc_taxi)')
    parser.add_argument('--filename',
                        type=str,
                        default='chfdb_chf13_45590.pkl',
                        help='filename of the dataset')
    parser.add_argument(
        '--model',
        type=str,
        default='LSTM',
        help='type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU, SRU)')
    parser.add_argument('--augment', type=bool, default=True, help='augment')
    parser.add_argument('--emsize',
                        type=int,
                        default=32,
                        help='size of rnn input features')
    parser.add_argument('--nhid',
                        type=int,
                        default=32,
                        help='number of hidden units per layer')
    parser.add_argument('--nlayers',
                        type=int,
                        default=2,
                        help='number of layers')
    parser.add_argument('--res_connection',
                        action='store_true',
                        help='residual connection')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='initial learning rate')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=1e-4,
                        help='weight decay')
    parser.add_argument('--clip',
                        type=float,
                        default=10,
                        help='gradient clipping')
    parser.add_argument('--epochs',
                        type=int,
                        default=400,
                        help='upper epoch limit')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='batch size')
    parser.add_argument('--eval_batch_size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='eval_batch size')
    parser.add_argument('--bptt', type=int, default=50, help='sequence length')
    parser.add_argument('--teacher_forcing_ratio',
                        type=float,
                        default=0.7,
                        help='teacher forcing ratio (deprecated)')
    parser.add_argument('--dropout',
                        type=float,
                        default=0.2,
                        help='dropout applied to layers (0 = no dropout)')
    parser.add_argument(
        '--tied',
        action='store_true',
        help='tie the word embedding and softmax weights (deprecated)')
    parser.add_argument('--seed', type=int, default=1111, help='random seed')
    parser.add_argument('--device',
                        type=str,
                        default='cuda',
                        help='cuda or cpu')
    parser.add_argument('--log_interval',
                        type=int,
                        default=10,
                        metavar='N',
                        help='report interval')
    parser.add_argument('--save_interval',
                        type=int,
                        default=10,
                        metavar='N',
                        help='save interval')
    parser.add_argument('--save_fig', action='store_true', help='save figure')
    parser.add_argument(
        '--resume',
        '-r',
        help=
        'use checkpoint model parameters as initial parameters (default: False)',
        action="store_true")
    parser.add_argument(
        '--pretrained',
        '-p',
        help=
        'use checkpoint model parameters and do not train anymore (default: False)',
        action="store_true")
    parser.add_argument('--prediction_window_size',
                        type=int,
                        default=10,
                        help='prediction_window_size')
    args = parser.parse_args()
    # Set the random seed manually for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    ###############################################################################
    # Load data
    ###############################################################################
    TimeseriesData = preprocess_data.PickleDataLoad(
        data_type=args.data,
        filename=args.filename,
        augment_test_data=args.augment)
    train_dataset = TimeseriesData.batchify(args, TimeseriesData.trainData,
                                            args.batch_size)
    test_dataset = TimeseriesData.batchify(args, TimeseriesData.testData,
                                           args.eval_batch_size)
    gen_dataset = TimeseriesData.batchify(args, TimeseriesData.testData, 1)

    ###############################################################################
    # Build the model
    ###############################################################################
    feature_dim = TimeseriesData.trainData.size(1)
    model = model.RNNPredictor(rnn_type=args.model,
                               enc_inp_size=feature_dim,
                               rnn_inp_size=args.emsize,
                               rnn_hid_size=args.nhid,
                               dec_out_size=feature_dim,
                               nlayers=args.nlayers,
                               dropout=args.dropout,
                               tie_weights=args.tied,
                               res_connection=args.res_connection).to(
                                   args.device)
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    criterion = nn.MSELoss()

    ###############################################################################
    # Training code
    ###############################################################################
    def get_batch(args, source, i):
        seq_len = min(args.bptt, len(source) - 1 - i)
        data = source[i:i + seq_len]  # [ seq_len * batch_size * feature_size ]
        target = source[i + 1:i + 1 +
                        seq_len]  # [ (seq_len x batch_size x feature_size) ]
        return data, target

    def generate_output(args,
                        epoch,
                        model,
                        gen_dataset,
                        disp_uncertainty=True,
                        startPoint=500,
                        endPoint=3500):
        if args.save_fig:
            # Turn on evaluation mode which disables dropout.
            model.eval()
            hidden = model.init_hidden(1)
            outSeq = []
            upperlim95 = []
            lowerlim95 = []
            with torch.no_grad():
                for i in range(endPoint):
                    if i >= startPoint:
                        # if disp_uncertainty and epoch > 40:
                        #     outs = []
                        #     model.train()
                        #     for i in range(20):
                        #         out_, hidden_ = model.forward(out+0.01*Variable(torch.randn(out.size())).cuda(),hidden,noise=True)
                        #         outs.append(out_)
                        #     model.eval()
                        #     outs = torch.cat(outs,dim=0)
                        #     out_mean = torch.mean(outs,dim=0) # [bsz * feature_dim]
                        #     out_std = torch.std(outs,dim=0) # [bsz * feature_dim]
                        #     upperlim95.append(out_mean + 2.58*out_std/np.sqrt(20))
                        #     lowerlim95.append(out_mean - 2.58*out_std/np.sqrt(20))

                        out, hidden = model.forward(out, hidden)

                        #print(out_mean,out)

                    else:
                        out, hidden = model.forward(
                            gen_dataset[i].unsqueeze(0), hidden)
                    outSeq.append(out.data.cpu()[0][0].unsqueeze(0))

            outSeq = torch.cat(outSeq, dim=0)  # [seqLength * feature_dim]

            target = preprocess_data.reconstruct(gen_dataset.cpu(),
                                                 TimeseriesData.mean,
                                                 TimeseriesData.std)
            outSeq = preprocess_data.reconstruct(outSeq, TimeseriesData.mean,
                                                 TimeseriesData.std)
            # if epoch>40:
            #     upperlim95 = torch.cat(upperlim95, dim=0)
            #     lowerlim95 = torch.cat(lowerlim95, dim=0)
            #     upperlim95 = preprocess_data.reconstruct(upperlim95.data.cpu().numpy(),TimeseriesData.mean,TimeseriesData.std)
            #     lowerlim95 = preprocess_data.reconstruct(lowerlim95.data.cpu().numpy(),TimeseriesData.mean,TimeseriesData.std)

            plt.figure(figsize=(15, 5))
            for i in range(target.size(-1)):
                plt.plot(target[:, :, i].numpy(),
                         label='Target' + str(i),
                         color='black',
                         marker='.',
                         linestyle='--',
                         markersize=1,
                         linewidth=0.5)
                plt.plot(range(startPoint),
                         outSeq[:startPoint, i].numpy(),
                         label='1-step predictions for target' + str(i),
                         color='green',
                         marker='.',
                         linestyle='--',
                         markersize=1.5,
                         linewidth=1)
                # if epoch>40:
                #     plt.plot(range(startPoint, endPoint), upperlim95[:,i].numpy(), label='upperlim'+str(i),
                #              color='skyblue', marker='.', linestyle='--', markersize=1.5, linewidth=1)
                #     plt.plot(range(startPoint, endPoint), lowerlim95[:,i].numpy(), label='lowerlim'+str(i),
                #              color='skyblue', marker='.', linestyle='--', markersize=1.5, linewidth=1)
                plt.plot(range(startPoint, endPoint),
                         outSeq[startPoint:, i].numpy(),
                         label='Recursive predictions for target' + str(i),
                         color='blue',
                         marker='.',
                         linestyle='--',
                         markersize=1.5,
                         linewidth=1)

            plt.xlim([startPoint - 500, endPoint])
            plt.xlabel('Index', fontsize=15)
            plt.ylabel('Value', fontsize=15)
            plt.title('Time-series Prediction on ' + args.data + ' Dataset',
                      fontsize=18,
                      fontweight='bold')
            plt.legend()
            plt.tight_layout()
            plt.text(startPoint - 500 + 10,
                     target.min(),
                     'Epoch: ' + str(epoch),
                     fontsize=15)
            save_dir = Path(
                'result', args.data,
                args.filename).with_suffix('').joinpath('fig_prediction')
            save_dir.mkdir(parents=True, exist_ok=True)
            plt.savefig(
                save_dir.joinpath('fig_epoch' +
                                  str(epoch)).with_suffix('.png'))
            #plt.show()
            plt.close()
            return outSeq

        else:
            pass

    def evaluate_1step_pred(args, model, test_dataset):
        # Turn on evaluation mode which disables dropout.
        model.eval()
        total_loss = 0
        with torch.no_grad():
            hidden = model.init_hidden(args.eval_batch_size)
            for nbatch, i in enumerate(
                    range(0,
                          test_dataset.size(0) - 1, args.bptt)):

                inputSeq, targetSeq = get_batch(args, test_dataset, i)
                outSeq, hidden = model.forward(inputSeq, hidden)

                loss = criterion(outSeq.view(args.eval_batch_size, -1),
                                 targetSeq.view(args.eval_batch_size, -1))
                hidden = model.repackage_hidden(hidden)
                total_loss += loss.item()

        return total_loss / (nbatch + 1)

    def train(args, model, train_dataset, epoch):

        with torch.enable_grad():
            # Turn on training mode which enables dropout.
            model.train()
            total_loss = 0
            start_time = time.time()
            hidden = model.init_hidden(args.batch_size)
            for batch, i in enumerate(
                    range(0,
                          train_dataset.size(0) - 1, args.bptt)):
                inputSeq, targetSeq = get_batch(args, train_dataset, i)
                # inputSeq: [ seq_len * batch_size * feature_size ]
                # targetSeq: [ seq_len * batch_size * feature_size ]

                # Starting each batch, we detach the hidden state from how it was previously produced.
                # If we didn't, the model would try backpropagating all the way to start of the dataset.
                hidden = model.repackage_hidden(hidden)
                hidden_ = model.repackage_hidden(hidden)
                optimizer.zero_grad()
                '''Loss1: Free running loss'''
                outVal = inputSeq[0].unsqueeze(0)
                outVals = []
                hids1 = []
                for i in range(inputSeq.size(0)):
                    outVal, hidden_, hid = model.forward(outVal,
                                                         hidden_,
                                                         return_hiddens=True)
                    outVals.append(outVal)
                    hids1.append(hid)
                outSeq1 = torch.cat(outVals, dim=0)
                hids1 = torch.cat(hids1, dim=0)
                loss1 = criterion(
                    outSeq1.contiguous().view(args.batch_size, -1),
                    targetSeq.contiguous().view(args.batch_size, -1))
                '''Loss2: Teacher forcing loss'''
                outSeq2, hidden, hids2 = model.forward(inputSeq,
                                                       hidden,
                                                       return_hiddens=True)
                loss2 = criterion(
                    outSeq2.contiguous().view(args.batch_size, -1),
                    targetSeq.contiguous().view(args.batch_size, -1))
                '''Loss3: Simplified Professor forcing loss'''
                loss3 = criterion(hids1.view(args.batch_size, -1),
                                  hids2.view(args.batch_size, -1).detach())
                '''Total loss = Loss1+Loss2+Loss3'''
                loss = loss1 + loss2 + loss3
                loss.backward()

                # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
                optimizer.step()

                total_loss += loss.item()

                if batch % args.log_interval == 0 and batch > 0:
                    cur_loss = total_loss / args.log_interval
                    elapsed = time.time() - start_time
                    print(
                        '| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.4f} | '
                        'loss {:5.2f} '.format(
                            epoch, batch,
                            len(train_dataset) // args.bptt,
                            elapsed * 1000 / args.log_interval, cur_loss))
                    total_loss = 0
                    start_time = time.time()

    def evaluate(args, model, test_dataset):
        # Turn on evaluation mode which disables dropout.
        model.eval()
        with torch.no_grad():
            total_loss = 0
            hidden = model.init_hidden(args.eval_batch_size)
            nbatch = 1
            for nbatch, i in enumerate(
                    range(0,
                          test_dataset.size(0) - 1, args.bptt)):
                inputSeq, targetSeq = get_batch(args, test_dataset, i)
                # inputSeq: [ seq_len * batch_size * feature_size ]
                # targetSeq: [ seq_len * batch_size * feature_size ]
                hidden_ = model.repackage_hidden(hidden)
                '''Loss1: Free running loss'''
                outVal = inputSeq[0].unsqueeze(0)
                outVals = []
                hids1 = []
                for i in range(inputSeq.size(0)):
                    outVal, hidden_, hid = model.forward(outVal,
                                                         hidden_,
                                                         return_hiddens=True)
                    outVals.append(outVal)
                    hids1.append(hid)
                outSeq1 = torch.cat(outVals, dim=0)
                hids1 = torch.cat(hids1, dim=0)
                loss1 = criterion(
                    outSeq1.contiguous().view(args.eval_batch_size, -1),
                    targetSeq.contiguous().view(args.eval_batch_size, -1))
                '''Loss2: Teacher forcing loss'''
                outSeq2, hidden, hids2 = model.forward(inputSeq,
                                                       hidden,
                                                       return_hiddens=True)
                loss2 = criterion(
                    outSeq2.contiguous().view(args.eval_batch_size, -1),
                    targetSeq.contiguous().view(args.eval_batch_size, -1))
                '''Loss3: Simplified Professor forcing loss'''
                loss3 = criterion(hids1.view(args.eval_batch_size, -1),
                                  hids2.view(args.eval_batch_size, -1).detach())
                '''Total loss = Loss1+Loss2+Loss3'''
                loss = loss1 + loss2 + loss3

                total_loss += loss.item()

        return total_loss / (nbatch + 1)

    # Loop over epochs.
    if args.resume or args.pretrained:
        print("=> loading checkpoint ")
        checkpoint = torch.load(
            Path('save', args.data, 'checkpoint',
                 args.filename).with_suffix('.pth'))
        args, start_epoch, best_val_loss = model.load_checkpoint(
            args, checkpoint, feature_dim)
        optimizer.load_state_dict((checkpoint['optimizer']))
        del checkpoint
        epoch = start_epoch
        print("=> loaded checkpoint")
    else:
        epoch = 1
        start_epoch = 1
        best_val_loss = float('inf')
        print("=> Start training from scratch")
    print('-' * 89)
    print(args)
    print('-' * 89)

    if not args.pretrained:
        # At any point you can hit Ctrl + C to break out of training early.
        try:
            for epoch in range(start_epoch, args.epochs + 1):

                epoch_start_time = time.time()
                train(args, model, train_dataset, epoch)
                val_loss = evaluate(args, model, test_dataset)
                print('-' * 89)
                print(
                    '| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.4f} | '
                    .format(epoch, (time.time() - epoch_start_time), val_loss))
                print('-' * 89)

                generate_output(args,
                                epoch,
                                model,
                                gen_dataset,
                                startPoint=1500)

                if epoch % args.save_interval == 0:
                    # Save the model if the validation loss is the best we've seen so far.
                    is_best = val_loss < best_val_loss
                    best_val_loss = min(val_loss, best_val_loss)
                    model_dictionary = {
                        'epoch': epoch,
                        'best_loss': best_val_loss,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'args': args
                    }
                    model.save_checkpoint(model_dictionary, is_best)

        except KeyboardInterrupt:
            print('-' * 89)
            print('Exiting from training early')

    # Calculate mean and covariance for each channel's prediction errors, and save them with the trained model
    print('=> calculating mean and covariance')
    means, covs = list(), list()
    train_dataset = TimeseriesData.batchify(args,
                                            TimeseriesData.trainData,
                                            bsz=1)
    for channel_idx in range(model.enc_input_size):
        mean, cov = fit_norm_distribution_param(
            args, model, train_dataset[:TimeseriesData.length], channel_idx)
        means.append(mean)
        covs.append(cov)
    model_dictionary = {
        'epoch': max(epoch, start_epoch),
        'best_loss': best_val_loss,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'args': args,
        'means': means,
        'covs': covs
    }
    model.save_checkpoint(model_dictionary, True)
    print('-' * 89)
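The training and evaluation loops above rely on model.repackage_hidden to detach the hidden state between BPTT windows, as the comment in train() explains; its body is not shown in these snippets. A minimal sketch of the usual implementation (an assumption about this repository's version), which handles both plain tensors and LSTM (h, c) tuples:

import torch


def repackage_hidden(h):
    # Detach hidden states from the graph of the previous batch so that
    # backpropagation stops at the current BPTT window.
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)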
Example #6
###############################################################################
# Load data
###############################################################################
TimeseriesData = preprocess_data.DataLoad(args.data)
train_dataset = preprocess_data.batchify(args, TimeseriesData.trainData,
                                         args.batch_size)
test_dataset = preprocess_data.batchify(args, TimeseriesData.testData,
                                        args.eval_batch_size)
gen_dataset = preprocess_data.batchify(args, TimeseriesData.testData, 1)

###############################################################################
# Build the model
###############################################################################
model = model.RNNPredictor(rnn_type=args.model,
                           enc_inp_size=3,
                           rnn_inp_size=args.emsize,
                           rnn_hid_size=args.nhid,
                           dec_out_size=3,
                           nlayers=args.nlayers,
                           dropout=args.dropout,
                           tie_weights=args.tied)
print(list(model.parameters()))
if args.cuda:
    model.cuda()
optimizer = optim.Adam(model.parameters(), lr=args.lr)
criterion = nn.MSELoss()


###############################################################################
# Training code
###############################################################################
def get_batch(source, i, evaluation=False):
    seq_len = min(args.bptt, len(source) - 1 - i)