def test(epoch, test_loader, model, optimizer):
    """Evaluate the model on the test set and return the average loss."""
    model.eval()

    test_loss = 0
    pathTmp = args.path + 'TMP/'
    N = len(test_loader.dataset)

    for batch_index, (outY, tmp_index) in enumerate(test_loader):
        outData = readBatch(tmp_index, pathTmp)
        outY = restructure(outY)
        data = Variable(outData)  # sequence length, batch size, input size
        Y = Variable(outY)

        if args.cuda:
            data = data.cuda()
            if input_type == 'y_projected':
                # NOTE: Y_projected is never assigned in this function; it is
                # assumed to come from an enclosing scope (not shown here).
                Y_projected = Y_projected.cuda()
            elif input_type == 'y':
                Y = Y.cuda()

        # input_type selects the model input: 'y_projected', 'y', or 'u'.
        if input_type == 'y_projected':
            if vae_type == 'deterministic':
                muTheta, mu = model(Y_projected)
            elif vae_type == 'stochastic':
                muTheta, logvarTheta, _, _ = model(Y_projected)

        elif input_type == 'y':
            if vae_type == 'deterministic':
                muTheta, mu = model(Y)
            elif vae_type == 'stochastic':
                muTheta, logvarTheta, _, _ = model(Y)

        elif input_type == 'u':
            if vae_type == 'deterministic':
                muTheta, mu = model(data)
            elif vae_type == 'stochastic':
                muTheta, logvarTheta, _, _ = model(data)

        # The reconstruction error is reported at test time for both
        # vae_type settings, so the deterministic loss is used here.
        loss = loss_function_deterministic(muTheta, data, args)

        test_loss += loss.data[0]

    if epoch % 50 == 0:
        # Plot reconstructions from the last test batch every 50 epochs.
        plotTMP(muTheta.cpu().data, data.cpu().data, epoch)

    avg_error = test_loss * args.batchsize / N
    print('====> Test Epoch: {} Average Test loss: {:.4f}'.format(
        epoch, avg_error))
    return torch.FloatTensor([avg_error])
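

# A minimal sketch of what loss_function_deterministic is assumed to
# compute: summed squared reconstruction error normalized by the total
# element count (sequence length x batch size x signal dimension). The
# actual implementation lives elsewhere in this repo; the name
# _loss_function_deterministic_sketch is illustrative only.
def _loss_function_deterministic_sketch(muTheta, data, args):
    seq_len, batch_size, dim = data.size()
    return torch.sum((muTheta - data).pow(2)) / (seq_len * batch_size * dim)
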
def train(epoch, train_loader, model, optimizer):
    """Run one training epoch and return the average loss."""
    model.train()
    train_loss = 0
    pathTmp = args.path + 'TMP/'
    N = len(train_loader.dataset)

    for batch_index, (outY, label) in enumerate(train_loader):
        outData = readBatch(label, pathTmp)
        outY = restructure(outY)
        data = Variable(outData)  # sequence length, batch size, input size
        Y = Variable(outY)

        # KL annealing schedule: hold the KL weight at 0 for the first 50
        # epochs, then set it to epoch/500 until it saturates at 1. Float
        # literals avoid integer division under Python 2.
        if isAnnealing:
            if epoch < 50:
                annealParam = 0.0
            elif epoch < 500:
                annealParam = epoch / 500.0
            else:
                annealParam = 1.0
        else:
            annealParam = 1.0
        annealParam = Variable(torch.FloatTensor([annealParam]))
        if args.cuda:
            data = data.cuda()
            Y = Y.cuda()
            annealParam = annealParam.cuda()

        optimizer.zero_grad()

        if input_type == 'y':
            if vae_type == 'deterministic':
                muTheta, mu = model(Y)
            elif vae_type == 'stochastic':
                muTheta, logvarTheta, mu, logvar = model(Y)

        elif input_type == 'u':
            if vae_type == 'deterministic':
                muTheta, mu = model(data)
            elif vae_type == 'stochastic':
                muTheta, logvarTheta, mu, logvar = model(data)

        if vae_type == 'deterministic':
            loss = loss_function_deterministic(muTheta, data, args)
        elif vae_type == 'stochastic':
            loss = loss_function(muTheta, logvarTheta, data, mu, logvar,
                                 annealParam, args)

        loss.backward()
        train_loss += loss.data[0]
        optimizer.step()

    train_loss_avg = train_loss * args.batchsize / N
    print('====> Train Epoch: {} Average Train loss: {:.4f}'.format(
        epoch, train_loss_avg))
    return torch.FloatTensor([train_loss_avg])
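

# A hedged sketch of the stochastic loss called in train(). It assumes the
# standard annealed VAE objective: a Gaussian reconstruction term plus
# annealParam times the KL divergence of the approximate posterior from a
# unit Gaussian. The real loss_function may differ; this only illustrates
# the call signature used above.
def _loss_function_sketch(muTheta, logvarTheta, data, mu, logvar,
                          annealParam, args):
    # Gaussian negative log-likelihood of the reconstruction (up to an
    # additive constant).
    recon = 0.5 * torch.sum(
        logvarTheta + (data - muTheta).pow(2) / torch.exp(logvarTheta))
    # KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dimensions.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - torch.exp(logvar))
    batch_size = data.size(1)  # data is (sequence, batch, dimension)
    return (recon + annealParam * kld) / batch_size


# Hypothetical driver showing how train() and test() fit together; the
# attribute args.epochs is assumed and may be named differently in the
# actual script.
def _run_sketch(model, optimizer, train_loader, test_loader):
    for epoch in range(1, args.epochs + 1):
        train(epoch, train_loader, model, optimizer)
        test(epoch, test_loader, model, optimizer)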