'''Loss3: Simplified Professor forcing loss'''
            loss3 = criterion(hids1.view(args.batch_size,-1), hids2.view(args.batch_size,-1).detach())

            '''Total loss = Loss1+Loss2+Loss3'''
            loss = loss1+loss2+loss3

            total_loss += loss.item()

    return total_loss / (nbatch+1)


# Loop over epochs.
# Restore model/optimizer state from a checkpoint when resuming or running a
# pretrained model; otherwise initialise the bookkeeping for a fresh run.
if args.resume or args.pretrained:
    print("=> loading checkpoint ")
    checkpoint = torch.load(Path('save', args.data, 'checkpoint', args.filename).with_suffix('.pth'))
    args, start_epoch, best_val_loss = model.load_checkpoint(args, checkpoint, feature_dim)
    optimizer.load_state_dict(checkpoint['optimizer'])
    del checkpoint  # drop the reference to the loaded checkpoint dict
    epoch = start_epoch
    print("=> loaded checkpoint")
else:
    epoch = 1
    start_epoch = 1
    # Start from +inf so the first validation loss counts as an improvement.
    # NOTE(review): presumably compared later as `val_loss < best_val_loss`
    # (as in the other copies of this loop); a start value of 0 could never
    # be beaten by a non-negative MSE loss.
    best_val_loss = float('inf')
    print("=> Start training from scratch")
print('-' * 89)
print(args)
print('-' * 89)

if not args.pretrained:
    # At any point you can hit Ctrl + C to break out of training early.
                              hids2.view(args.batch_size, -1).detach())
            '''Total loss = Loss1+Loss2+Loss3'''
            loss = loss1 + loss2 + loss3

            total_loss += loss.item()

    return total_loss / (nbatch + 1)


# Loop over epochs.
# Two start-up modes: a fresh run initialises the epoch counter and the
# best-loss sentinel; resume/pretrained mode restores them (plus model and
# optimizer state) from the checkpoint under `args.path_save`.
if not (args.resume or args.pretrained):
    start_epoch = 1
    epoch = start_epoch
    best_val_loss = float('inf')  # any finite validation loss improves on this
    print("=> Start training from scratch")
else:
    print("=> loading checkpoint ")
    checkpoint = torch.load(
        Path(args.path_save + '/save',
             args.data,
             'checkpoint',
             args.filename).with_suffix('.pth'))
    args, start_epoch, best_val_loss = model.load_checkpoint(
        args, checkpoint, feature_dim)
    optimizer.load_state_dict(checkpoint['optimizer'])
    del checkpoint  # drop the reference to the loaded checkpoint dict
    epoch = start_epoch
    print("=> loaded checkpoint")
for banner_line in ('-' * 89, args, '-' * 89):
    print(banner_line)

if not args.pretrained:
    # At any point you can hit Ctrl + C to break out of training early.
Example #3
0
def main():
    """Run training.

    Parse command-line arguments, load and batchify the time-series dataset,
    and build an ``RNNPredictor`` with an Adam optimizer and an MSE criterion.
    Unless ``--pretrained`` is given, train it with the combined
    free-running / teacher-forcing / simplified professor-forcing loss.
    Finally, fit a per-channel mean and covariance of the prediction errors
    on the training data (via ``fit_norm_distribution_param``). Save them
    together with the model in a checkpoint.
    """
    from model import model

    def _str2bool(value):
        """Convert a command-line string to ``bool``.

        ``argparse`` with ``type=bool`` maps every non-empty string,
        including ``"False"``, to ``True``. An explicit parser is therefore
        needed.

        Raises:
            argparse.ArgumentTypeError: if ``value`` is not a recognised
                boolean spelling.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'boolean value expected, got {!r}'.format(value))

    parser = argparse.ArgumentParser(
        description='PyTorch RNN Prediction Model on Time-series Dataset')
    parser.add_argument(
        '--data',
        type=str,
        default='ecg',
        help=
        'type of the dataset (ecg, gesture, power_demand, space_shuttle, respiration, nyc_taxi'
    )
    parser.add_argument('--filename',
                        type=str,
                        default='chfdb_chf13_45590.pkl',
                        help='filename of the dataset')
    parser.add_argument(
        '--model',
        type=str,
        default='LSTM',
        help='type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU, SRU)')
    parser.add_argument('--augment',
                        type=_str2bool,
                        default=True,
                        help='augment')
    parser.add_argument('--emsize',
                        type=int,
                        default=32,
                        help='size of rnn input features')
    parser.add_argument('--nhid',
                        type=int,
                        default=32,
                        help='number of hidden units per layer')
    parser.add_argument('--nlayers',
                        type=int,
                        default=2,
                        help='number of layers')
    parser.add_argument('--res_connection',
                        action='store_true',
                        help='residual connection')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='initial learning rate')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=1e-4,
                        help='weight decay')
    parser.add_argument('--clip',
                        type=float,
                        default=10,
                        help='gradient clipping')
    parser.add_argument('--epochs',
                        type=int,
                        default=400,
                        help='upper epoch limit')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='batch size')
    parser.add_argument('--eval_batch_size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='eval_batch size')
    parser.add_argument('--bptt', type=int, default=50, help='sequence length')
    parser.add_argument('--teacher_forcing_ratio',
                        type=float,
                        default=0.7,
                        help='teacher forcing ratio (deprecated)')
    parser.add_argument('--dropout',
                        type=float,
                        default=0.2,
                        help='dropout applied to layers (0 = no dropout)')
    parser.add_argument(
        '--tied',
        action='store_true',
        help='tie the word embedding and softmax weights (deprecated)')
    parser.add_argument('--seed', type=int, default=1111, help='random seed')
    parser.add_argument('--device',
                        type=str,
                        default='cuda',
                        help='cuda or cpu')
    parser.add_argument('--log_interval',
                        type=int,
                        default=10,
                        metavar='N',
                        help='report interval')
    parser.add_argument('--save_interval',
                        type=int,
                        default=10,
                        metavar='N',
                        help='save interval')
    parser.add_argument('--save_fig', action='store_true', help='save figure')
    parser.add_argument(
        '--resume',
        '-r',
        help=
        'use checkpoint model parameters as initial parameters (default: False)',
        action="store_true")
    parser.add_argument(
        '--pretrained',
        '-p',
        help=
        'use checkpoint model parameters and do not train anymore (default: False)',
        action="store_true")
    parser.add_argument('--prediction_window_size',
                        type=int,
                        default=10,
                        help='prediction_window_size')
    args = parser.parse_args()
    # Set the random seed manually for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    ###############################################################################
    # Load data
    ###############################################################################
    TimeseriesData = preprocess_data.PickleDataLoad(
        data_type=args.data,
        filename=args.filename,
        augment_test_data=args.augment)
    train_dataset = TimeseriesData.batchify(args, TimeseriesData.trainData,
                                            args.batch_size)
    test_dataset = TimeseriesData.batchify(args, TimeseriesData.testData,
                                           args.eval_batch_size)
    gen_dataset = TimeseriesData.batchify(args, TimeseriesData.testData, 1)

    ###############################################################################
    # Build the model
    ###############################################################################
    feature_dim = TimeseriesData.trainData.size(1)
    model = model.RNNPredictor(rnn_type=args.model,
                               enc_inp_size=feature_dim,
                               rnn_inp_size=args.emsize,
                               rnn_hid_size=args.nhid,
                               dec_out_size=feature_dim,
                               nlayers=args.nlayers,
                               dropout=args.dropout,
                               tie_weights=args.tied,
                               res_connection=args.res_connection).to(
                                   args.device)
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    criterion = nn.MSELoss()

    ###############################################################################
    # Training code
    ###############################################################################
    def get_batch(args, source, i):
        """Return the (input, target) BPTT window starting at row ``i``.

        ``target`` is ``data`` shifted one step ahead. The window is
        truncated so that the target never runs past the end of ``source``.
        """
        seq_len = min(args.bptt, len(source) - 1 - i)
        data = source[i:i + seq_len]  # [ seq_len * batch_size * feature_size ]
        target = source[i + 1:i + 1 +
                        seq_len]  # [ (seq_len x batch_size x feature_size) ]
        return data, target

    def generate_output(args,
                        epoch,
                        model,
                        gen_dataset,
                        disp_uncertainty=True,
                        startPoint=500,
                        endPoint=3500):
        """Plot 1-step and recursive predictions and save the figure.

        For steps ``[0, startPoint)`` the model is fed the ground truth from
        ``gen_dataset``. For steps ``[startPoint, endPoint)`` it is fed its
        own previous output. Does nothing unless ``args.save_fig`` is set.

        ``disp_uncertainty`` is kept for backward compatibility. The
        uncertainty-band code it used to control is currently disabled.

        Returns:
            The de-normalised prediction sequence, or ``None`` when
            ``args.save_fig`` is false.

        Raises:
            ValueError: if ``startPoint`` is not in ``[1, endPoint]``. The
                recursive phase needs at least one seeded prediction, and the
                plots slice ``startPoint`` entries out of ``endPoint``.
        """
        if not args.save_fig:
            return None
        if not 1 <= startPoint <= endPoint:
            raise ValueError(
                'startPoint must satisfy 1 <= startPoint <= endPoint, got '
                'startPoint={}, endPoint={}'.format(startPoint, endPoint))

        # Turn on evaluation mode which disables dropout.
        model.eval()
        hidden = model.init_hidden(1)
        outSeq = []
        with torch.no_grad():
            for i in range(endPoint):
                if i < startPoint:
                    # Teacher-forced phase: feed the ground-truth sample.
                    out, hidden = model.forward(gen_dataset[i].unsqueeze(0),
                                                hidden)
                else:
                    # Recursive phase: feed back the previous prediction.
                    out, hidden = model.forward(out, hidden)
                outSeq.append(out.data.cpu()[0][0].unsqueeze(0))

        outSeq = torch.cat(outSeq, dim=0)  # [seqLength * feature_dim]

        target = preprocess_data.reconstruct(gen_dataset.cpu(),
                                             TimeseriesData.mean,
                                             TimeseriesData.std)
        outSeq = preprocess_data.reconstruct(outSeq, TimeseriesData.mean,
                                             TimeseriesData.std)

        plt.figure(figsize=(15, 5))
        for i in range(target.size(-1)):
            plt.plot(target[:, :, i].numpy(),
                     label='Target' + str(i),
                     color='black',
                     marker='.',
                     linestyle='--',
                     markersize=1,
                     linewidth=0.5)
            plt.plot(range(startPoint),
                     outSeq[:startPoint, i].numpy(),
                     label='1-step predictions for target' + str(i),
                     color='green',
                     marker='.',
                     linestyle='--',
                     markersize=1.5,
                     linewidth=1)
            plt.plot(range(startPoint, endPoint),
                     outSeq[startPoint:, i].numpy(),
                     label='Recursive predictions for target' + str(i),
                     color='blue',
                     marker='.',
                     linestyle='--',
                     markersize=1.5,
                     linewidth=1)

        plt.xlim([startPoint - 500, endPoint])
        plt.xlabel('Index', fontsize=15)
        plt.ylabel('Value', fontsize=15)
        plt.title('Time-series Prediction on ' + args.data + ' Dataset',
                  fontsize=18,
                  fontweight='bold')
        plt.legend()
        plt.tight_layout()
        plt.text(startPoint - 500 + 10,
                 target.min(),
                 'Epoch: ' + str(epoch),
                 fontsize=15)
        save_dir = Path(
            'result', args.data,
            args.filename).with_suffix('').joinpath('fig_prediction')
        save_dir.mkdir(parents=True, exist_ok=True)
        plt.savefig(
            save_dir.joinpath('fig_epoch' + str(epoch)).with_suffix('.png'))
        plt.close()
        return outSeq

    def evaluate_1step_pred(args, model, test_dataset):
        """Return the mean teacher-forced (1-step-ahead) MSE per batch.

        Assumes ``test_dataset`` was batchified with ``args.eval_batch_size``
        columns, as it is in this function's caller. Returns 0 when the
        dataset yields no batches.
        """
        # Turn on evaluation mode which disables dropout.
        model.eval()
        total_loss = 0
        nbatches = 0
        with torch.no_grad():
            hidden = model.init_hidden(args.eval_batch_size)
            for i in range(0, test_dataset.size(0) - 1, args.bptt):
                inputSeq, targetSeq = get_batch(args, test_dataset, i)
                outSeq, hidden = model.forward(inputSeq, hidden)

                loss = criterion(outSeq.view(args.eval_batch_size, -1),
                                 targetSeq.view(args.eval_batch_size, -1))
                hidden = model.repackage_hidden(hidden)
                total_loss += loss.item()
                nbatches += 1

        # Average over the actual number of batches processed.
        return total_loss / max(nbatches, 1)

    def _combined_loss(model, inputSeq, targetSeq, hidden, bsz):
        """Compute loss1 + loss2 + loss3 for one BPTT window.

        * loss1 (free running): starting from ``inputSeq[0]``, the model is
          fed its own previous output for ``inputSeq.size(0)`` steps. This
          pass starts from a repackaged (detached) copy of ``hidden``.
        * loss2 (teacher forcing): the model is fed the ground-truth
          ``inputSeq``, starting from ``hidden``.
        * loss3 (simplified professor forcing): MSE between the free-running
          hidden states and the detached teacher-forced hidden states.

        ``bsz`` is the batch size the sequences were batchified with. It is
        used only to flatten tensors before passing them to ``criterion``.

        Returns:
            ``(loss, hidden)``, where ``hidden`` is the state after the
            teacher-forced pass.
        """
        hidden_ = model.repackage_hidden(hidden)

        # Loss1: free running loss
        outVal = inputSeq[0].unsqueeze(0)
        outVals = []
        hids1 = []
        for _ in range(inputSeq.size(0)):
            outVal, hidden_, hid = model.forward(outVal,
                                                 hidden_,
                                                 return_hiddens=True)
            outVals.append(outVal)
            hids1.append(hid)
        outSeq1 = torch.cat(outVals, dim=0)
        hids1 = torch.cat(hids1, dim=0)
        loss1 = criterion(outSeq1.contiguous().view(bsz, -1),
                          targetSeq.contiguous().view(bsz, -1))

        # Loss2: teacher forcing loss
        outSeq2, hidden, hids2 = model.forward(inputSeq,
                                               hidden,
                                               return_hiddens=True)
        loss2 = criterion(outSeq2.contiguous().view(bsz, -1),
                          targetSeq.contiguous().view(bsz, -1))

        # Loss3: simplified professor forcing loss. The teacher-forced hidden
        # states are the (fixed) target, hence detach().
        loss3 = criterion(hids1.view(bsz, -1),
                          hids2.view(bsz, -1).detach())

        return loss1 + loss2 + loss3, hidden

    def train(args, model, train_dataset, epoch):
        """Run one training epoch over ``train_dataset``.

        Each BPTT window is optimised on the combined free-running,
        teacher-forcing and simplified professor-forcing loss, with gradient
        clipping at ``args.clip``. Every ``args.log_interval`` batches the
        mean loss and time per batch since the last report are printed.
        """
        with torch.enable_grad():
            # Turn on training mode which enables dropout.
            model.train()
            total_loss = 0
            batches_since_log = 0
            start_time = time.time()
            hidden = model.init_hidden(args.batch_size)
            for batch, i in enumerate(
                    range(0,
                          train_dataset.size(0) - 1, args.bptt)):
                inputSeq, targetSeq = get_batch(args, train_dataset, i)
                # inputSeq: [ seq_len * batch_size * feature_size ]
                # targetSeq: [ seq_len * batch_size * feature_size ]

                # Starting each batch, we detach the hidden state from how it was previously produced.
                # If we didn't, the model would try backpropagating all the way to start of the dataset.
                hidden = model.repackage_hidden(hidden)
                optimizer.zero_grad()
                loss, hidden = _combined_loss(model, inputSeq, targetSeq,
                                              hidden, args.batch_size)
                loss.backward()

                # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
                optimizer.step()

                total_loss += loss.item()
                batches_since_log += 1

                if batch % args.log_interval == 0 and batch > 0:
                    # Divide by the number of batches actually accumulated:
                    # the first window spans batches 0..log_interval.
                    cur_loss = total_loss / batches_since_log
                    elapsed = time.time() - start_time
                    print(
                        '| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.4f} | '
                        'loss {:5.2f} '.format(
                            epoch, batch,
                            len(train_dataset) // args.bptt,
                            elapsed * 1000 / batches_since_log, cur_loss))
                    total_loss = 0
                    batches_since_log = 0
                    start_time = time.time()

    def evaluate(args, model, test_dataset):
        """Return the mean combined loss (loss1 + loss2 + loss3) per batch.

        Assumes ``test_dataset`` was batchified with ``args.eval_batch_size``
        columns, as it is in this function's caller. Returns 0 when the
        dataset yields no batches.
        """
        # Turn on evaluation mode which disables dropout.
        model.eval()
        total_loss = 0
        nbatches = 0
        with torch.no_grad():
            hidden = model.init_hidden(args.eval_batch_size)
            for i in range(0, test_dataset.size(0) - 1, args.bptt):
                inputSeq, targetSeq = get_batch(args, test_dataset, i)
                # inputSeq: [ seq_len * batch_size * feature_size ]
                # targetSeq: [ seq_len * batch_size * feature_size ]
                loss, hidden = _combined_loss(model, inputSeq, targetSeq,
                                              hidden, args.eval_batch_size)
                total_loss += loss.item()
                nbatches += 1

        return total_loss / max(nbatches, 1)

    # Loop over epochs.
    if args.resume or args.pretrained:
        print("=> loading checkpoint ")
        # map_location lets a checkpoint saved on another device (e.g. GPU)
        # be loaded onto the device selected for this run.
        checkpoint = torch.load(Path('save', args.data, 'checkpoint',
                                     args.filename).with_suffix('.pth'),
                                map_location=args.device)
        args, start_epoch, best_val_loss = model.load_checkpoint(
            args, checkpoint, feature_dim)
        optimizer.load_state_dict(checkpoint['optimizer'])
        del checkpoint
        epoch = start_epoch
        print("=> loaded checkpoint")
    else:
        epoch = 1
        start_epoch = 1
        best_val_loss = float('inf')
        print("=> Start training from scratch")
    print('-' * 89)
    print(args)
    print('-' * 89)

    if not args.pretrained:
        # At any point you can hit Ctrl + C to break out of training early.
        try:
            for epoch in range(start_epoch, args.epochs + 1):

                epoch_start_time = time.time()
                train(args, model, train_dataset, epoch)
                val_loss = evaluate(args, model, test_dataset)
                print('-' * 89)
                print(
                    '| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.4f} | '
                    .format(epoch, (time.time() - epoch_start_time), val_loss))
                print('-' * 89)

                generate_output(args,
                                epoch,
                                model,
                                gen_dataset,
                                startPoint=1500)

                if epoch % args.save_interval == 0:
                    # Save the model if the validation loss is the best we've seen so far.
                    is_best = val_loss < best_val_loss
                    best_val_loss = min(val_loss, best_val_loss)
                    model_dictionary = {
                        'epoch': epoch,
                        'best_loss': best_val_loss,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'args': args
                    }
                    model.save_checkpoint(model_dictionary, is_best)

        except KeyboardInterrupt:
            print('-' * 89)
            print('Exiting from training early')

    # Calculate mean and covariance for each channel's prediction errors, and save them with the trained model
    print('=> calculating mean and covariance')
    means, covs = list(), list()
    train_dataset = TimeseriesData.batchify(args,
                                            TimeseriesData.trainData,
                                            bsz=1)
    for channel_idx in range(model.enc_input_size):
        mean, cov = fit_norm_distribution_param(
            args, model, train_dataset[:TimeseriesData.length], channel_idx)
        means.append(mean)
        covs.append(cov)
    model_dictionary = {
        'epoch': max(epoch, start_epoch),
        'best_loss': best_val_loss,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'args': args,
        'means': means,
        'covs': covs
    }
    model.save_checkpoint(model_dictionary, True)
    print('-' * 89)