# Example #1
def train(obj, optimizer, dataset, xp, args, epoch):
    """Run one training epoch of ``optimizer`` on the objective ``obj``.

    For every ``(i, x, y)`` yielded by ``optimizer.get_sampler(dataset)``, the
    oracle of ``obj`` is queried at the current weights, the sample index is
    attached to the oracle output, and the optimizer takes a step. Per-batch
    objective, task error (measured after the step) and batch size are pushed
    to ``xp`` through ``update_metrics``. After the epoch a summary line is
    printed and the metrics are logged for ``epoch``.

    ``args`` is part of the signature but is not read by this function.
    """
    xp.Timer_Train.reset()
    batch_stats = {}

    sampler = optimizer.get_sampler(dataset)
    n_batches = optimizer.get_sampler_len(dataset)
    progress = tqdm(sampler, desc='Train Epoch', leave=False, total=n_batches)

    for idx, x, y in progress:
        info = obj.oracle(optimizer.variables.w, x, y)
        info['i'] = idx
        optimizer.step(info)

        # Monitoring statistics; the task error uses the post-step weights.
        batch_stats.update(
            obj=float(info['obj']),
            error=float(obj.task_error(optimizer.variables.w, x, y)),
            size=float(x.size(0)),
        )
        update_metrics(xp, batch_stats)

    xp.Timer_Train.update()

    summary = ('\nEpoch: [{0}] (Train) \t'
               '({timer:.2f}s) \t'
               'Obj {obj:.3f}\t'
               'Error {error:.2f}\t')
    print(summary.format(int(xp.Epoch.value),
                         timer=xp.Timer_Train.value,
                         error=xp.Error_Train.value,
                         obj=xp.Obj_Train.value))
    log_metrics(xp, epoch)
# Example #2
def train(model, loss, optimizer, loader, xp, args):
    """Train ``model`` for one epoch over ``loader``.

    Each batch is moved to the GPU when ``args.cuda`` is set and scored by the
    model. The training loss is computed inside
    ``set_smoothing_enabled(args.smooth_svm)``, and the optimizer receives a
    closure returning that loss as a float. Per-batch monitoring statistics
    are pushed to ``xp``: the loss recomputed outside the smoothing context,
    the accuracy, ``optimizer.gamma`` and the batch size. After the epoch,
    step size, regularization, objective and timer are updated, a summary
    line is printed, and the metrics are logged.
    """
    model.train()
    xp.Timer_Train.reset()
    monitor = {}

    batches = tqdm(loader, disable=not args.tqdm, desc='Train Epoch',
                   leave=False, total=len(loader))
    for x, y in batches:
        if args.cuda:
            x, y = x.cuda(), y.cuda()

        scores = model(x)

        # The training loss may use smoothing, depending on args.smooth_svm.
        with set_smoothing_enabled(args.smooth_svm):
            objective = loss(scores, y)

        optimizer.zero_grad()
        objective.backward()
        # Some optimizers need the loss value to pick a step size.
        optimizer.step(lambda: float(objective))

        # Monitoring: this loss evaluation is outside the smoothing context.
        monitor['loss'] = float(loss(scores, y))
        monitor['acc'] = float(accuracy(scores, y))
        monitor['gamma'] = float(optimizer.gamma)
        monitor['size'] = float(scores.size(0))
        update_metrics(xp, monitor)

    xp.Eta.update(optimizer.eta)
    xp.Reg.update(regularization(model, args.l2))
    xp.Obj_Train.update(xp.Reg.value + xp.Loss_Train.value)
    xp.Timer_Train.update()

    template = ('\nEpoch: [{0}] (Train) \t'
                '({timer:.2f}s) \t'
                'Obj {obj:.3f}\t'
                'Loss {loss:.3f}\t'
                'Acc {acc:.2f}%\t')
    print(template.format(int(xp.Epoch.value),
                          timer=xp.Timer_Train.value,
                          acc=xp.Acc_Train.value,
                          obj=xp.Obj_Train.value,
                          loss=xp.Loss_Train.value))

    log_metrics(xp)
# Example #3
def evaluate(model, loss_fn, test_loader, params, dirs, istest=False):
    '''Evaluate the model on the test set.

    Args:
        model: (torch.nn.Module) the Deep AR model
        loss_fn: a function that takes outputs and labels per timestep, and
            then computes the loss for the batch (not used by this function;
            kept for interface compatibility)
        test_loader: load test data and labels
        params: (Params) hyperparameters; ``params.pred_start`` is the number
            of leading time steps fed through the model before predicting
        dirs: provides ``device`` (where batches are moved) and ``model_dir``
            (where the SQF plot parameters are written when ``istest``)
        istest: when True, pickle the model output and input at step
            ``params.pred_start`` of the first batch to
            ``<model_dir>/sqf_param`` and log only the mean CRPS. When False,
            log CRPS, the mean of the per-row max |mre|, and PINAW.

    Returns:
        dict with 'CRPS_Mean', 'mre' (mean |mre|) and 'pinaw', plus one
        'CRPS_{i}' entry per element of ``summary['CRPS']`` and one
        'mre_{i}' entry per element of ``summary['mre'].mean(dim=0)``.
    '''
    import pickle  # only needed for the SQF-parameter dump below

    logger = logging.getLogger('DeepAR.Eval')
    model.eval()
    with torch.no_grad():
        metrics = utils.init_metrics(params, dirs)

        for i, (test_batch, labels) in enumerate(tqdm(test_loader)):
            # Swap the first two axes; afterwards axis 1 is the batch axis.
            test_batch = test_batch.permute(1, 0, 2).to(torch.float32).to(
                dirs.device)
            labels = labels.to(torch.float32).to(dirs.device)
            batch_size = test_batch.shape[1]
            hidden = model.init_hidden(batch_size)
            cell = model.init_cell(batch_size)

            # Feed the first pred_start steps one at a time, keeping only the
            # recurrent state for the prediction below.
            for t in range(params.pred_start):
                _, hidden, cell = model(test_batch[t].unsqueeze(0), hidden,
                                        cell)

            # Save some params of SQF for plotting (first test batch only).
            if istest and i == 0:
                plot_param, _, _ = model(
                    test_batch[params.pred_start].unsqueeze(0), hidden, cell)
                save_name = os.path.join(dirs.model_dir, 'sqf_param')
                with open(save_name, 'wb') as f:
                    pickle.dump(plot_param, f)
                    pickle.dump(test_batch[params.pred_start], f)

            samples, _, _ = model.predict(test_batch,
                                          hidden,
                                          cell,
                                          sampling=True)
            metrics = utils.update_metrics(metrics, samples, labels,
                                           params.pred_start)
        summary = utils.final_metrics(metrics)
        if not istest:
            # mre: max of |mre| along dim 1, then averaged.
            strings = ('\nCRPS: ' + str(summary['CRPS']) +
                       '\nmre:' + str(summary['mre'].abs().max(dim=1)[0].mean().item()) +
                       '\nPINAW:' + str(summary['pinaw'].item()))
            logger.info('- Full test metrics: ' + strings)
        else:
            logger.info(' - Test Set CRPS: ' +
                        str(summary['CRPS'].mean().item()))

    ss_metric = {
        'CRPS_Mean': summary['CRPS'].mean(),
        'mre': summary['mre'].abs().mean(),
        'pinaw': summary['pinaw'],
    }
    for i, crps in enumerate(summary['CRPS']):
        ss_metric[f'CRPS_{i}'] = crps
    for i, mre in enumerate(summary['mre'].mean(dim=0)):
        ss_metric[f'mre_{i}'] = mre
    return ss_metric
# Example #4
def evaluate(model, test_loader, params, plot_num):
    '''Evaluate the model on the test set.

    Args:
        model: (torch.nn.Module) the model passed to ``transformer.test``
        test_loader: yields ``(test_batch, id_batch, v, labels)`` batches
        params: (Params) hyperparameters; reads ``device``, ``predict_steps``,
            ``test_predict_start`` and ``relative_metrics``
        plot_num: (-1): evaluation from evaluate.py; else (epoch): evaluation
            on epoch. Not used by this function; kept for interface
            compatibility.

    Returns:
        dict of metrics from ``utils.final_metrics`` extended with:
        - 'q50': quantile loss at level 0.5 of the mu forecast
        - 'q90': quantile loss at level 0.9 of the q90 forecast
        - 'MAPE'
        All three are computed over the concatenation of every batch against
        the last ``params.predict_steps`` label columns.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    '''
    model.eval()
    with torch.no_grad():
        raw_metrics = utils.init_metrics()
        # Collect per-batch tensors and concatenate once at the end; calling
        # torch.cat every batch would re-copy everything accumulated so far.
        mu_chunks, q90_chunks, true_chunks = [], [], []

        for test_batch, id_batch, v, labels in tqdm(test_loader):
            test_batch = test_batch.to(torch.float32).to(params.device)
            id_batch = id_batch.unsqueeze(-1).to(params.device)
            v_batch = v.to(torch.float32).to(params.device)
            labels = labels.to(torch.float32).to(params.device)

            sample_mu, sample_q90 = transformer.test(model, params, test_batch,
                                                     v_batch, id_batch)
            raw_metrics = utils.update_metrics(
                raw_metrics,
                sample_mu,
                labels,
                params.test_predict_start,
                relative=params.relative_metrics)

            mu_chunks.append(sample_mu)
            q90_chunks.append(sample_q90)
            true_chunks.append(labels[:, -params.predict_steps:])

        if not mu_chunks:
            raise ValueError('test_loader yielded no batches')

        sum_mu = torch.cat(mu_chunks, 0)
        sum_q90 = torch.cat(q90_chunks, 0)
        true = torch.cat(true_chunks, 0)

        summary_metric = utils.final_metrics(raw_metrics)
        summary_metric['q50'] = transformer.quantile_loss(0.5, sum_mu, true)
        # Score the 90th-percentile forecast at quantile level 0.9, matching
        # the metric's name (0.5 would score it as if it were a median).
        summary_metric['q90'] = transformer.quantile_loss(0.9, sum_q90, true)
        summary_metric['MAPE'] = transformer.MAPE(sum_mu, true)

        metrics_string = '; '.join('{}: {:05.3f}'.format(k, v)
                                   for k, v in summary_metric.items())

        logger.info('- Full test metrics: ' + metrics_string)
    return summary_metric
# Example #5
def evaluate_dataset(model, dataloader):
    """Accumulate PSNR metrics of ``model`` over ``dataloader``.

    The model is switched to eval mode for the pass and afterwards restored
    to whichever train/eval mode it was in on entry. The mode is restored
    even if an exception is raised. Gradients are disabled for the whole
    pass.

    Args:
        model: module called as ``model(images_lowres, images_fullres)``.
        dataloader: yields ``(images_lowres, images_fullres, targets)``.

    Returns:
        The accumulator returned by the last ``utils.update_metrics`` call,
        or ``None`` if ``dataloader`` yields nothing.
    """
    was_training = model.training
    model.eval()
    running_metrics = None
    try:
        with torch.no_grad():
            for images_lowres, images_fullres, targets in dataloader:
                images_lowres, images_fullres, targets = utils.device(
                    [images_lowres, images_fullres, targets])
                predictions = model(images_lowres, images_fullres)
                psnr = compute_psnr(predictions, targets)
                running_metrics = utils.update_metrics(
                    running_metrics, {'psnr': psnr}, len(dataloader))
    finally:
        # Restore the caller's mode rather than forcing train mode.
        model.train(was_training)
    return running_metrics
# Example #6
def evaluate(model, loss_fn, test_loader, params, plot_num, sample=True):
    '''Evaluate the model on the test set.

    For every batch, the conditioning range is fed through the model one
    step at a time. The remaining steps are then predicted with
    ``model.test`` (with ancestral sampling when ``sample`` is True), and
    the metrics are accumulated with ``utils.update_metrics``. For one
    randomly chosen batch, 20 windows are plotted with
    ``plot_eight_windows``. The final metrics are logged and returned.

    Args:
        model: (torch.nn.Module) the Deep AR model
        loss_fn: a function that takes outputs and labels per timestep, and then computes the loss for the batch
            (not used in this function body)
        test_loader: load test data and labels
        params: (Params) hyperparameters
        plot_num: (-1): evaluation from evaluate.py; else (epoch): evaluation on epoch
        sample: (boolean) do ancestral sampling or directly use output mu from last time step

    Returns:
        dict returned by ``utils.final_metrics(raw_metrics, sampling=sample)``.
    '''
    model.eval()
    with torch.no_grad():
      # Index of the batch to plot.
      # NOTE(review): randint(high) samples from [0, high), so the last batch
      # is never picked. A single-batch loader gives high == 0, which numpy
      # rejects with ValueError.
      plot_batch = np.random.randint(len(test_loader)-1)

      summary_metric = {}
      raw_metrics = utils.init_metrics(sample=sample)

      # Test_loader:
      # test_batch ([batch_size, train_window, 1+cov_dim]): z_{0:T-1} + x_{1:T}, note that z_0 = 0;
      # id_batch ([batch_size]): one integer denoting the time series id;
      # v ([batch_size, 2]): scaling factor for each window;
      # labels ([batch_size, train_window]): z_{1:T}.
      for i, (test_batch, id_batch, v, labels) in enumerate(tqdm(test_loader)):
          # Swap the first two axes; afterwards axis 1 is the batch axis.
          test_batch = test_batch.permute(1, 0, 2).to(torch.float32).to(params.device)
          id_batch = id_batch.unsqueeze(0).to(params.device)
          v_batch = v.to(torch.float32).to(params.device)
          labels = labels.to(torch.float32).to(params.device)
          batch_size = test_batch.shape[1]
          input_mu = torch.zeros(batch_size, params.test_predict_start, device=params.device) # scaled
          input_sigma = torch.zeros(batch_size, params.test_predict_start, device=params.device) # scaled
          hidden = model.init_hidden(batch_size)
          cell = model.init_cell(batch_size)

          # Conditioning range: one model step per time index. mu/sigma are
          # rescaled with v (v[:, 0] * mu + v[:, 1], v[:, 0] * sigma) and
          # stored for plotting.
          for t in range(params.test_predict_start):
              # if z_t is missing, replace it by output mu from the last time step
              # (a value of exactly 0 is treated as missing; mu comes from step t-1)
              zero_index = (test_batch[t,:,0] == 0)
              if t > 0 and torch.sum(zero_index) > 0:
                  test_batch[t,zero_index,0] = mu[zero_index]

              mu, sigma, hidden, cell = model(test_batch[t].unsqueeze(0), id_batch, hidden, cell)
              input_mu[:,t] = v_batch[:, 0] * mu + v_batch[:, 1]
              input_sigma[:,t] = v_batch[:, 0] * sigma

          # Prediction range. With sampling, model.test also returns the drawn
          # samples, which are forwarded to the metrics.
          if sample:
              samples, sample_mu, sample_sigma = model.test(test_batch, v_batch, id_batch, hidden, cell, sampling=True)
              raw_metrics = utils.update_metrics(raw_metrics, input_mu, input_sigma, sample_mu, labels, params.test_predict_start, samples, relative = params.relative_metrics)
          else:
              sample_mu, sample_sigma = model.test(test_batch, v_batch, id_batch, hidden, cell)
              raw_metrics = utils.update_metrics(raw_metrics, input_mu, input_sigma, sample_mu, labels, params.test_predict_start, relative = params.relative_metrics)

          if i == plot_batch:
              # Per-window metrics for this batch, used to rank and annotate the plots.
              if sample:
                  sample_metrics = utils.get_metrics(sample_mu, labels, params.test_predict_start, samples, relative = params.relative_metrics)
              else:
                  sample_metrics = utils.get_metrics(sample_mu, labels, params.test_predict_start, relative = params.relative_metrics)
              # select 10 from samples with highest error and 10 from the rest
              # top_10_nd_sample: indices of the batch_size // 10 windows with the
              # largest ND (argsort of -ND yields descending order).
              top_10_nd_sample = (-sample_metrics['ND']).argsort()[:batch_size // 10]  # hard coded to be 10
              chosen = set(top_10_nd_sample.tolist())
              all_samples = set(range(batch_size))
              not_chosen = np.asarray(list(all_samples - chosen))
              # Small batches draw with replacement so that 10 indices can always
              # be taken. NOTE(review): if batch_size < 10 the top pool is empty
              # and np.random.choice raises ValueError.
              if batch_size < 100: # make sure there are enough unique samples to choose top 10 from
                  random_sample_10 = np.random.choice(top_10_nd_sample, size=10, replace=True)
              else:
                  random_sample_10 = np.random.choice(top_10_nd_sample, size=10, replace=False)
              # Despite its name, random_sample_90 holds 10 indices drawn from the
              # windows not in the top pool.
              if batch_size < 12: # make sure there are enough unique samples to choose bottom 90 from
                  random_sample_90 = np.random.choice(not_chosen, size=10, replace=True)
              else:
                  random_sample_90 = np.random.choice(not_chosen, size=10, replace=False)
              combined_sample = np.concatenate((random_sample_10, random_sample_90))

              # Gather the 20 chosen windows on the CPU. The plotted mu/sigma are
              # the conditioning-range values followed by the predictions.
              label_plot = labels[combined_sample].data.cpu().numpy()
              predict_mu = sample_mu[combined_sample].data.cpu().numpy()
              predict_sigma = sample_sigma[combined_sample].data.cpu().numpy()
              plot_mu = np.concatenate((input_mu[combined_sample].data.cpu().numpy(), predict_mu), axis=1)
              plot_sigma = np.concatenate((input_sigma[combined_sample].data.cpu().numpy(), predict_sigma), axis=1)
              plot_metrics = {_k: _v[combined_sample] for _k, _v in sample_metrics.items()}
              plot_eight_windows(params.plot_dir, plot_mu, plot_sigma, label_plot, params.test_window, params.test_predict_start, plot_num, plot_metrics, sample)

      summary_metric = utils.final_metrics(raw_metrics, sampling=sample)
      metrics_string = '; '.join('{}: {:05.3f}'.format(k, v) for k, v in summary_metric.items())
      logger.info('- Full test metrics: ' + metrics_string)
    return summary_metric
# Example #7
def set_df(df, settings, metrics, db):
    """Update velocity, collision and target state for this module's machines.

    Works in place on ``df``, restricted to the rows whose ``machine_type``
    equals the last dotted component of this module's ``__name__``:

    1. For each such row, record the nearest other row (searched over all of
       ``df``) in ``ix_near`` / ``x_near`` / ``y_near``.
    2. Compute the displacement to the target (``dx``, ``dy``, ``ds``) and
       the negated displacement to the nearest neighbour (``dx1``, ``dy1``,
       ``ds1``).
    3. For rows at least ``MIN_D`` from their target, set ``(u, v)`` to the
       direction towards the target with magnitude ``min(ds, reward * speed)``.
       Then add ``(u1, v1) = (min(1 / dx1**2, penalty * speed),
       min(1 / dy1**2, penalty * speed))``.
    4. For rows flagged in ``collision``, replace ``(u, v)`` with ``(v, -u)``
       and report the count as the 'collisions' metric.
    5. For rows closer than ``MIN_D`` to their target, draw a new random
       target, report 'pickups' / 'dropoffs', and toggle ``state`` between
       'empty' and 'carry'. When ``USE_BIGCHAINDB`` is set, a BigchainDB
       asset is also created and transferred.

    Args:
        df: pandas DataFrame with all machines; modified in place.
        settings: mapping with 'S', 'MIN_D', 'MAX_X', 'MAX_Y',
            'USE_BIGCHAINDB' and ``settings['machines'][machine_type]``
            holding 'speed', 'reward' and 'penalty'.
        metrics, db: forwarded unchanged to ``update_metrics``.
    """
    machine_type = __name__.split('.')[-1]
    # S is unpacked (and so required) but not used below.
    S, MIN_D, MAX_X, MAX_Y, USE_BIGCHAINDB = (
        settings[k] for k in ('S', 'MIN_D', 'MAX_X', 'MAX_Y', 'USE_BIGCHAINDB'))
    s = (df['machine_type'] == machine_type)

    speed, reward, penalty = (settings['machines'][machine_type][k]
                              for k in ('speed', 'reward', 'penalty'))

    col_x_trg, col_y_trg = 'x_trg', 'y_trg'

    # Nearest neighbour of every machine of this type, over all other rows.
    for ix, row in df.loc[s].iterrows():
        others = df[df.index != ix]
        neighbour_ix = dist(row.x, row.y, others.x, others.y).idxmin()
        df.loc[ix, 'ix_near'] = int(neighbour_ix)
        df.loc[ix, 'x_near'] = df.loc[neighbour_ix, 'x']
        df.loc[ix, 'y_near'] = df.loc[neighbour_ix, 'y']

    # Displacement towards the target (reward term).
    df.loc[s, 'ds'] = dist(df.loc[s, 'x'], df.loc[s, 'y'],
                           df.loc[s, col_x_trg], df.loc[s, col_y_trg])
    df.loc[s, 'dx'] = df.loc[s, col_x_trg] - df.loc[s, 'x']
    df.loc[s, 'dy'] = df.loc[s, col_y_trg] - df.loc[s, 'y']

    # Displacement away from the nearest neighbour (penalty term).
    df.loc[s, 'ds1'] = dist(df.loc[s, 'x'], df.loc[s, 'y'],
                            df.loc[s, 'x_near'], df.loc[s, 'y_near'])
    df.loc[s, 'dx1'] = -(df.loc[s, 'x_near'] - df.loc[s, 'x'])
    df.loc[s, 'dy1'] = -(df.loc[s, 'y_near'] - df.loc[s, 'y'])

    # Step length towards the target is min(ds, reward * speed).
    # Series.clip(upper=...) replaces clip_upper, which pandas 1.0 removed.
    selection = s & (df['ds'] >= MIN_D)
    fr = df.loc[selection]
    step = fr.ds.clip(upper=reward * speed)
    df.loc[selection, 'u'] = step * fr['dx'] / fr['ds']
    df.loc[selection, 'v'] = step * fr['dy'] / fr['ds']

    # Neighbour term, capped at penalty * speed.
    # NOTE(review): 1 / dx1**2 is never negative, so the sign of dx1 (the
    # direction away from the neighbour) is lost here.
    df.loc[selection, 'u1'] = (1.0 / (fr['dx1'] ** 2)).clip(upper=penalty * speed)
    df.loc[selection, 'v1'] = (1.0 / (fr['dy1'] ** 2)).clip(upper=penalty * speed)

    # TODO: doesn't really work
    df.loc[selection, 'u'] += df.loc[selection, 'u1']
    df.loc[selection, 'v'] += df.loc[selection, 'v1']

    # On collision rotate the velocity by 90 degrees: (u, v) -> (v, -u).
    # Described as "turn left"; the handedness depends on the y-axis
    # orientation of the display — TODO confirm.
    selection = s & (df['collision'])
    u = df.loc[selection, 'v']
    v = -df.loc[selection, 'u']
    df.loc[selection, 'u'] = u
    df.loc[selection, 'v'] = v
    update_metrics(db, metrics, 'collisions', machine_type, len(df.loc[selection]))

    # Target reached: draw a new random target (a blockchain-provided target
    # is the intended future source).
    selection = s & (df['ds'].abs() < MIN_D)
    n_reached = len(df.loc[selection])
    if n_reached > 0:
        print('targets reached', n_reached)
    df.loc[selection, 'x_trg'] = rnd_vec(n_reached, MAX_X)
    df.loc[selection, 'y_trg'] = rnd_vec(n_reached, MAX_Y)

    # A binary 'empty' / 'carry' state describes whether the machine holds an
    # asset (e.g. energy or a parcel); reaching a target toggles it.
    update_metrics(db, metrics, 'pickups', machine_type,
                   len(df[selection & (df['state'] == 'empty')]))
    update_metrics(db, metrics, 'dropoffs', machine_type,
                   len(df[selection & (df['state'] == 'carry')]))

    for ix in df.loc[selection].index:
        state = df.loc[ix, 'state']
        if state == 'empty':
            print('empty bot changed to carry')
            current_bot = df.loc[ix]
            if USE_BIGCHAINDB:
                # Mock-up sender: create a parcel, then hand it to this bot.
                createAsset(master_sender_private_key, master_sender_public_key)
                transferAsset(master_sender_public_key, master_sender_private_key,
                              current_bot.public_key, current_bot.private_key)
            df.loc[ix, 'state'] = 'carry'
        elif state == 'carry':
            print('carry bot changed to empty')
            current_bot = df.loc[ix]
            if USE_BIGCHAINDB:
                # Hand the parcel from this bot to the master receiver.
                transferAsset(current_bot.public_key, current_bot.private_key,
                              master_receiver_public_key, master_receiver_private_key)
            df.loc[ix, 'state'] = 'empty'
                targets) in enumerate(train_loader):
        step_num = utils.step_num(epoch, batch, train_loader)

        images_lowres, images_fullres, targets = utils.device(
            [images_lowres, images_fullres, targets])

        predictions = model(images_lowres, images_fullres)
        loss = F.mse_loss(predictions, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_metrics = utils.update_metrics(
            running_metrics, {
                'mse': loss.item(),
                'psnr': compute_psnr_from_mse(loss)
            }, print_freq)

        if (step_num + 1) % print_freq == 0:
            utils.log_to_tensorboard(writer, running_metrics, step_num)
            utils.print_metrics(running_metrics, step_num,
                                n_epochs * len(train_loader))
            running_metrics = {l: 0 for l in running_metrics}

    print('Finished epoch %d\n' % epoch)

    print('Evaluation')
    eval_metrics = evaluate_dataset(model, test_loader)
    utils.log_to_tensorboard(writer, eval_metrics, step_num, log_prefix='test')
    utils.print_metrics(eval_metrics, step_num, n_epochs * len(train_loader))
# Example #9
def trainepoch(epoch):
    """Train the module-level ``nli_net`` for one epoch on ``train``.

    Reads the module-level names ``xp``, ``nli_net``, ``train``, ``params``,
    ``optimizer``, ``loss_fn`` and ``word_vec``. The training triples are
    shuffled with one shared permutation. When ``params.opt == 'sgd'`` and
    ``epoch > 1``, the learning rate is multiplied by ``params.decay``. Each
    batch is moved to the GPU and scored; the loss is computed under
    ``set_smoothing_enabled(params.smooth_svm)`` and back-propagated.
    ``adapt_grad_norm`` is applied unless the optimizer is 'dfw'. Per-batch
    statistics go to ``xp``, and an epoch summary is printed and logged.

    Args:
        epoch: epoch index, used in the printed messages; learning-rate decay
            applies only when ``epoch > 1``.
    """
    xp.Epoch.update(1).log()
    print('\nTRAINING : Epoch ' + str(epoch))
    nli_net.train()
    # shuffle the data (same permutation keeps s1 / s2 / label aligned)
    permutation = np.random.permutation(len(train['s1']))

    s1 = train['s1'][permutation]
    s2 = train['s2'][permutation]
    target = train['label'][permutation]

    # SGD learning-rate decay; eta mirrors the lr and is logged via xp.Eta.
    if epoch > 1 and params.opt == 'sgd':
        optimizer.param_groups[0]['lr'] *= params.decay
        optimizer.eta = optimizer.param_groups[0]['lr']

    xp.Timer_Train.reset()
    stats = {}

    for stidx in tqdm(range(0, len(s1), params.batch_size),
                      disable=not params.tqdm,
                      desc='Train Epoch',
                      leave=False):
        # prepare batch
        s1_batch, s1_len = get_batch(s1[stidx:stidx + params.batch_size],
                                     word_vec)
        s2_batch, s2_len = get_batch(s2[stidx:stidx + params.batch_size],
                                     word_vec)
        s1_batch, s2_batch = s1_batch.cuda(), s2_batch.cuda()
        tgt_batch = torch.LongTensor(target[stidx:stidx +
                                            params.batch_size]).cuda()

        # model forward
        scores = nli_net((s1_batch, s1_len), (s2_batch, s2_len))
        with set_smoothing_enabled(params.smooth_svm):
            loss = loss_fn(scores, tgt_batch)

        # backward
        optimizer.zero_grad()
        loss.backward()
        if params.opt != 'dfw':
            adapt_grad_norm(nli_net, params.max_norm)
        # necessary information for the step-size of some optimizers -> provide closure
        optimizer.step(lambda: float(loss))

        # track statistics for monitoring (this loss evaluation is outside
        # the smoothing context)
        stats['loss'] = float(loss_fn(scores, tgt_batch))
        stats['acc'] = float(accuracy(scores, tgt_batch))
        stats['gamma'] = float(optimizer.gamma)
        stats['size'] = float(tgt_batch.size(0))
        update_metrics(xp, stats)

    xp.Eta.update(optimizer.eta)
    xp.Reg.update(regularization(nli_net, params.l2))
    xp.Obj_Train.update(xp.Reg.value + xp.Loss_Train.value)
    xp.Timer_Train.update()

    # Mean training accuracy comes from the Acc_Train metric, the same one
    # used in the summary line below.
    print('results : epoch {0} ; mean accuracy train : {1}'.format(
        epoch, xp.Acc_Train.value))
    print('\nEpoch: [{0}] (Train) \t'
          '({timer:.2f}s) \t'
          'Obj {obj:.3f}\t'
          'Loss {loss:.3f}\t'
          'Acc {acc:.2f}%\t'.format(int(xp.Epoch.value),
                                    timer=xp.Timer_Train.value,
                                    acc=xp.Acc_Train.value,
                                    obj=xp.Obj_Train.value,
                                    loss=xp.Loss_Train.value))

    log_metrics(xp)
def evaluate(model, loss_fn, test_loader, params, sample=True):
    '''
    Evaluate the model on the test set.

    Args:
        model: (torch.nn.Module) the Deep AR model
        loss_fn: a function that takes outputs and labels per timestep, and
            then computes the loss for the batch (not used by this function;
            kept for interface compatibility)
        test_loader: load test data and labels
        params: (Params) hyperparameters; reads ``device``,
            ``test_predict_start``, ``one_step`` and ``relative_metrics``
        sample: (boolean) do ancestral sampling or directly use output mu from last time step

    Returns:
        ``(summary_metric, result_mu, result_sigma)`` where ``summary_metric``
        is the dict from ``utils.final_metrics`` and ``result_mu`` /
        ``result_sigma`` list, per batch, the ``sample_mu`` / ``sample_sigma``
        returned by ``model.test``.
    '''
    model.eval()
    with torch.no_grad():
        raw_metrics = utils.init_metrics(sample=sample)

        # Test_loader:
        # test_batch ([batch_size, train_window, 1+cov_dim]): z_{0:T-1} + x_{1:T}, note that z_0 = 0;
        # id_batch ([batch_size]): one integer denoting the time series id;
        # v ([batch_size, 2]): scaling factor for each window;
        # labels ([batch_size, train_window]): z_{1:T}.

        result_mu = []
        result_sigma = []
        for test_batch, id_batch, v, labels in test_loader:
            # Swap the first two axes; afterwards axis 1 is the batch axis.
            test_batch = test_batch.permute(1, 0, 2).to(torch.float32).to(params.device)
            id_batch = id_batch.unsqueeze(0).to(params.device)
            v_batch = v.to(torch.float32).to(params.device)
            labels = labels.to(torch.float32).to(params.device)
            batch_size = test_batch.shape[1]
            input_mu = torch.zeros(batch_size, params.test_predict_start, device=params.device)  # scaled
            input_sigma = torch.zeros(batch_size, params.test_predict_start, device=params.device)  # scaled
            hidden = model.init_hidden(batch_size)
            cell = model.init_cell(batch_size)

            # Encoder part: run the conditioning range one step at a time.
            for t in range(params.test_predict_start):
                # If z_t is missing (exactly 0), replace it with the mu the
                # model output at the previous time step.
                zero_index = (test_batch[t, :, 0] == 0)
                if t > 0 and torch.sum(zero_index) > 0:
                    test_batch[t, zero_index, 0] = mu[zero_index]

                mu, sigma, hidden, cell = model(test_batch[t].unsqueeze(0), id_batch, hidden, cell)
                # Rescale with v. NOTE(review): v_batch[:, 1] is believed to
                # always be 0 (offset unused) — TODO confirm.
                input_mu[:, t] = v_batch[:, 0] * mu + v_batch[:, 1]
                input_sigma[:, t] = v_batch[:, 0] * sigma

            if not params.one_step:
                # Seed the first prediction step with the last encoder mu.
                test_batch[params.test_predict_start, :, 0] = mu

            # Decoder part: predict the remaining steps.
            if sample:
                samples, sample_mu, sample_sigma = model.test(test_batch, v_batch, id_batch, hidden, cell, sampling=True, one_step=params.one_step)
                raw_metrics = utils.update_metrics(raw_metrics, input_mu, input_sigma, sample_mu, labels, params.test_predict_start, samples, relative=params.relative_metrics)
            else:
                sample_mu, sample_sigma = model.test(test_batch, v_batch, id_batch, hidden, cell, one_step=params.one_step)
                raw_metrics = utils.update_metrics(raw_metrics, input_mu, input_sigma, sample_mu, labels, params.test_predict_start, relative=params.relative_metrics)
            result_mu.append(sample_mu)
            result_sigma.append(sample_sigma)

        summary_metric = utils.final_metrics(raw_metrics, sampling=sample)
    return summary_metric, result_mu, result_sigma