Example #1
    def validation(self, cfg, mod, opt, val, chk):
        print_time('Start VALID')
        hyp_data = []
        ref_data = []
        loss_total = 0
        Iter = 0
        for (src_batch, tgt_batch, ref_batch, raw_src_batch, raw_tgt_batch,
             len_src_batch, len_tgt_batch) in val.minibatches():
            if cfg.cuda:
                src_batch = src_batch.cuda()
                tgt_batch = tgt_batch.cuda()
                ref_batch = ref_batch.cuda()
            dec_outputs, dec_output_words = mod(
                src_batch, tgt_batch, len_src_batch,
                len_tgt_batch)  ### forward returns: [T,B,V] [T,B]
            loss = F.nll_loss(
                dec_outputs.permute(1, 0, 2).contiguous().view(-1, cfg.tvoc.size),
                ref_batch.contiguous().view(-1),
                ignore_index=cfg.tvoc.idx_pad)  # loss normalized by word
            loss_total += loss.item()
            Iter += 1
            hyp_data.extend(dec_output_words.permute(1, 0))
            ref_data.extend(ref_batch)
        # update learning rate
        lr = opt.update_lr(loss_total)
        if Iter > 0:
            print_time('VALID iter:{} loss={:.4f} bleu={:.2f}'.format(
                cfg.n_iters_sofar, loss_total / Iter,
                self.get_bleu(hyp_data, ref_data, cfg)))
            chk.save(cfg, mod, opt, loss_total / Iter)
        return lr
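The permute/contiguous/view chain above is the easiest place for this code to go wrong, so here is a self-contained sketch of the same reshape with toy sizes (PAD stands in for cfg.tvoc.idx_pad; dec_outputs is assumed to hold log-probabilities, which is what F.nll_loss expects):

import torch
import torch.nn.functional as F

T, B, V, PAD = 5, 3, 10, 0                       # toy sizes; PAD mimics cfg.tvoc.idx_pad
dec_outputs = torch.log_softmax(torch.randn(T, B, V), dim=-1)  # [T,B,V] log-probs
ref_batch = torch.randint(1, V, (B, T))          # [B,T] reference word indices
flat_out = dec_outputs.permute(1, 0, 2).contiguous().view(-1, V)  # [B*T,V]
flat_ref = ref_batch.contiguous().view(-1)       # [B*T]
loss = F.nll_loss(flat_out, flat_ref, ignore_index=PAD)  # mean over non-pad words
print(loss.item())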
Example #2
def post_json_files(root):
    """
    Post json objects in a designated directory to BuildingOS.

    Params:
        root string
    """
    json_dir = defaults.json_dir(root)
    archive = defaults.json_archive(root)
    post_url = defaults.BOS_URL

    json_files = utils.get_files_in_dir(json_dir)
    if not json_files:
        utils.warn('No JSON files to process. Terminating')
        exit()

    utils.print_time('LOADER START')
    for json_file in json_files:
        print('Posting file: %s ...' % json_file, end='')
        with open(json_file, 'rb') as jf:
            payload = {'data': jf}
            response = requests.post(post_url, files=payload)
            print('done')

            print('Server response: %s' % (response.text))

        utils.move(json_file, archive)

    utils.print_time('LOADER END')
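requests encodes the files= mapping as a multipart/form-data upload, with 'data' as the form field the server reads; a minimal sketch of the same call against a stand-in endpoint (httpbin.org here is an assumption, not the real defaults.BOS_URL, and a local dump.json is assumed to exist):

import requests

with open('dump.json', 'rb') as jf:
    # the open file object is streamed as the 'data' form field
    response = requests.post('https://httpbin.org/post', files={'data': jf})
print(response.status_code)
print(response.text[:200])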
Example #3
    def load(self, par, file=None):
        if file is None:  ### load the most recent checkpoint in self.path
            if not os.path.exists(self.path):
                sys.exit('error: no experiments found in {}'.format(self.path))
            all_saves = sorted(glob.glob(self.path + '/checkpoint_*.pt'),
                               reverse=True)
            if len(all_saves) == 0:
                sys.exit('error: no checkpoint found in dir={}'.format(
                    self.path))
            checkpoint = all_saves[0]
        else:  ### load the checkpoint given in file
            if not os.path.exists(file):
                sys.exit('error: no checkpoint found in {}'.format(file))
            checkpoint = file
        chk = torch.load(checkpoint)
        ### load cfg
        cfg = chk['cfg']
        cfg.update_par(par)
        ### load model
        mod = Model(cfg)  #builds a model using cfg options
        mod.load_state_dict(chk['mod_state'])  #loads the model saved
        if cfg.cuda: mod.cuda()  ### move to GPU
        ### load optimizer
        opt = Optimizer(cfg, mod)
        opt.optimizer.load_state_dict(chk['opt_state'])

        print_time("Loaded It={} {}".format(cfg.n_iters_sofar, checkpoint))
        return cfg, mod, opt
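Note that picking all_saves[0] after a reverse sort only works because save() (Example #5 below) embeds a %Y%m%d-%H%M%S stamp in the filename, so lexicographic order matches chronological order; a quick illustration with made-up names:

names = ['checkpoint_20240101-090000_000100_0.95000.pt',
         'checkpoint_20240102-090000_000200_0.80000.pt']
# reverse lexicographic sort puts the newest timestamp first
assert sorted(names, reverse=True)[0].startswith('checkpoint_20240102')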
Example #4
def create_json(root):
    """
    Create the json file containing reading data.

    Params:
        root string
    """
    data_dir = defaults.downloads(root)
    output_dir = defaults.json_dir(root)
    archive = defaults.data_archive(root)

    catalog = []
    data = []
    json_file = {}

    data_files = utils.get_files_in_dir(data_dir)
    if not data_files:
        utils.warn('No csv files to process. Terminating')
        exit()

    utils.print_time('PROCESSOR START')
    print('Begin JSON file generation')
    for data_file in data_files:
        with open(data_file, 'r', newline='') as f:
            reader = csv.reader(f)
            meterId, meterName = next(reader)

            print('Processing meterId %s ...' % meterId, end='')

            info = {'meterId': meterId, 'meterName': meterName}
            catalog.append(info)

            for row in reader:
                ts = row[0]
                val = float(row[1])
                reading = {'timestamp': ts,
                           'value': val,
                           'meterId': meterId}
                data.append(reading)

            print('done')
        utils.move(data_file, archive)

    json_file['datasource'] = defaults.URI
    json_file['meterCatalog'] = catalog
    json_file['readings'] = data

    print('End JSON file generation')

    curr_dt = datetime.now()
    json_fname = 'dump_%s.json' % (utils.format_dt(curr_dt))
    save_path = os.path.join(output_dir, json_fname)

    print('Writing JSON to file %s ...' % save_path, end='')
    with open(save_path, 'w') as out:
        json.dump(json_file, out)
        print('done')

    utils.print_time('PROCESSOR END')
Example #5
    def save(self, cfg, mod, opt, loss):
        if not os.path.exists(self.path): os.makedirs(self.path)
        date_time = time.strftime('%Y%m%d-%H%M%S', time.localtime())
        checkpoint = os.path.join(
            self.path,
            'checkpoint_{}_{:0>6}_{}.pt'.format(date_time, cfg.n_iters_sofar,
                                                "{:.5f}".format(loss)[0:7]))
        chk = {
            'mod_state': mod.state_dict(),
            'opt_state': opt.state_dict(),
            'cfg': cfg
        }
        torch.save(chk, checkpoint)
        print_time("Saved checkpoint [{}]".format(checkpoint))
Example #6
    def __init__(self, cfg, mod, tst):
        self.cfg = cfg
        ini_time = time.time()
        print_time('Start TEST')
        with torch.no_grad():
            for val_iter, (src_batch, tgt_batch, ref_batch, raw_src_batch,
                           raw_tgt_batch, len_src_batch,
                           len_tgt_batch) in enumerate(tst.minibatches()):
                if cfg.cuda:
                    src_batch = src_batch.cuda()
                    tgt_batch = tgt_batch.cuda()

                _, hyp_batch = mod(src_batch, tgt_batch, len_src_batch,
                                   len_tgt_batch)  ### forward
                self.display(src_batch, ref_batch, hyp_batch)

        print_time('End of TEST seconds={:.2f}\n'.format(time.time() -
                                                         ini_time))
Example #7
    def pretrainpca(self):
        print("Pre-training started!")
        trainset = VideoSet(self.opt, state='train')
        trainloader = DataLoader(trainset,
                                 batch_size=2048,
                                 shuffle=True,
                                 num_workers=self.opt.num_workers,
                                 drop_last=False)
        times = []
        start = time.time()
        times.append(start)
        for ii, (train_input, _) in enumerate(trainloader):
            time1 = time.time()
            print_time(times[-1], time1,
                       "data loading for training iteration %d" % (ii + 1))
            train_input = self.statistic_feature_extraction(train_input)
            print("Starting training iteration %d" % (ii + 1))
            self.pca.fit(train_input)
            time2 = time.time()
            print("Training iteration %d finished" % (ii + 1), end=",")
            print_time(time1, time2)
            times.append(time1)
            times.append(time2)
            print()
        print_time(times[0], times[-1], "entire PCA model training process")
        self.pca.save(self.opt.incremental_pca_params_path)
        self.PCA_state = True
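A note on the fit-per-batch loop above: given self.opt.incremental_pca_params_path, self.pca is presumably a wrapper around something like sklearn's IncrementalPCA, where the batch-wise update is partial_fit (a plain fit would discard earlier batches); a minimal sketch of that pattern, independent of this codebase:

import numpy as np
from sklearn.decomposition import IncrementalPCA

ipca = IncrementalPCA(n_components=16)
for _ in range(5):                        # stand-in for the DataLoader loop
    batch = np.random.randn(2048, 64)     # toy feature batch [batch, features]
    ipca.partial_fit(batch)               # incremental update, one batch at a time
print(ipca.explained_variance_ratio_[:3])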
Example #8
def main():

    par = Params(sys.argv)
    random.seed(par.seed)
    torch.manual_seed(par.seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(par.seed)

    if par.trn and par.val:
        chk = Checkpoint(par.dir)

        if chk.contains_model:  ####### resume training ####################################
            cfg, mod, opt = chk.load(par)  ### also moves to GPU if cfg.cuda
            #            cfg.update_par(par) ### updates par in cfg
            print_time('Learning [resume It={}]...'.format(cfg.n_iters_sofar))

        else:  ######################## training from scratch ##############################
            cfg = Config(par)  ### reads cfg and par (reads vocabularies)
            mod = Model(cfg)
            if cfg.cuda: mod.cuda()  ### moves to GPU
            opt = Optimizer(cfg, mod)  #build Optimizer
            print_time('Learning [from scratch]...')

        trn = Dataset(par.trn,
                      cfg.svoc,
                      cfg.tvoc,
                      par.batch_size,
                      par.max_src_len,
                      par.max_tgt_len,
                      do_shuffle=True,
                      do_filter=True,
                      is_test=False)
        val = Dataset(par.val,
                      cfg.svoc,
                      cfg.tvoc,
                      par.batch_size,
                      par.max_src_len,
                      par.max_tgt_len,
                      do_shuffle=True,
                      do_filter=True,
                      is_test=True)
        Training(cfg, mod, opt, trn, val, chk)

    elif par.tst:  #################### inference ##########################################
        chk = Checkpoint()
        cfg, mod, opt = chk.load(par, par.chk)
        #        cfg.update_par(par) ### updates cfg options with pars
        tst = Dataset(par.tst,
                      cfg.svoc,
                      cfg.tvoc,
                      par.batch_size,
                      0,
                      0,
                      do_shuffle=False,
                      do_filter=False,
                      is_test=True)
        print_time('Inference [model It={}]...'.format(cfg.n_iters_sofar))
        Inference(cfg, mod, tst)
Example #9
    def __init__(self, cfg, mod, opt, trn, val, chk):
        if cfg.n_iters_sofar is None: cfg.n_iters_sofar = 0
        ini_time = time.time()
        print_time('Start TRAIN')
        lr = cfg.par.lr
        loss_total_N_iters = 0  # Reset every [print_every]
        Iter = 0
        for (src_batch, tgt_batch, ref_batch, raw_src_batch, raw_tgt_batch,
             len_src_batch, len_tgt_batch) in trn.minibatches():
            assert (len(src_batch) == len(tgt_batch) == len(ref_batch) ==
                    len(raw_src_batch) == len(raw_tgt_batch) ==
                    len(len_src_batch) == len(len_tgt_batch))
            if cfg.cuda:
                src_batch = src_batch.cuda()
                tgt_batch = tgt_batch.cuda()
                ref_batch = ref_batch.cuda()
            dec_outputs, dec_output_words = mod(
                src_batch, tgt_batch, len_src_batch,
                len_tgt_batch)  # forward returns: [T,B,V] [T,B]
            dec_outputs = dec_outputs.permute(1, 0, 2).contiguous().view(
                -1, cfg.tvoc.size)
            ref_batch = ref_batch.contiguous().view(-1)
            loss = F.nll_loss(
                dec_outputs, ref_batch,
                ignore_index=cfg.tvoc.idx_pad)  # loss normalized by word
            loss_total_N_iters += loss.item()
            mod.zero_grad()  # reset gradients
            loss.backward()  # Backward propagation
            opt.step()
            cfg.n_iters_sofar += 1
            Iter += 1
            if Iter % cfg.par.print_every == 0:
                print_time('TRAIN iter:{} lr={:.5f} loss={:.4f}'.format(
                    cfg.n_iters_sofar, lr,
                    loss_total_N_iters / cfg.par.print_every))
                loss_total_N_iters = 0
            if Iter % cfg.par.valid_every == 0:
                lr = self.validation(cfg, mod, opt, val, chk)
            if Iter >= cfg.par.n_iters:
                break

        print_time('End of TRAIN seconds={:.2f}'.format(time.time() -
                                                        ini_time))
Example #10
def train(**kwargs):
    # update the configuration from command-line arguments
    opt = DefaultConfig()
    opt.parse(kwargs)
    print("Configuration complete")

    # optimizer
    learning_rate = opt.learning_rate
    # the default optimizer is Adam
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                       beta1=0.5,
                                       beta2=0.9)
    if opt.optimizer_type == "SGD":
        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
    elif opt.optimizer_type == "Momentum":
        momentum = opt.momentum
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=momentum)
    elif opt.optimizer_type == "Adam":
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                           beta1=0.5,
                                           beta2=0.9)

    # build the static graph
    with tf.Graph().as_default():
        with tf.name_scope("inputs"):
            inputs = tf.placeholder("float", [None, 24, 2, 2],
                                    name="model_input")
            labels = tf.placeholder("float", [None, 72, 14, 2], name="labels")

        # define the model; collect and group the trainable parameters
        model = []
        if opt.model_type == 1:  # deconvolution
            gmodel = GModel(opt.batch_size, opt.normal_type, True,
                            "generate_model")
            model.append(gmodel)
        elif opt.model_type == 2:  # deconvolution + learnable pooling
            gmodel = GModel(opt.batch_size, opt.normal_type, True,
                            "generate_model")
            model.append(gmodel)
            learningpoolingmodel = LearningPoolingModel(
                opt.batch_size, opt.normal_type, True, opt.model_2_layers,
                "learning_pooling_model")
            model.append(learningpoolingmodel)
        elif opt.model_type == 3:  # deconvolution + GAN
            gmodel = GModel(opt.batch_size, opt.normal_type, True,
                            "generate_model")
            model.append(gmodel)
            dmodel = DModel(opt.batch_size, opt.normal_type, True,
                            opt.GAN_type, "discriminate_model")
            model.append(dmodel)
        # print(model)

        # collect and group the parameters to train
        # because of the dependency on tf.GraphKeys.UPDATE_OPS below, get_vars
        # must be called after calculate_loss, otherwise all_vars would be empty
        def get_vars():
            all_vars = tf.trainable_variables()
            # print(all_vars)
            gg_vars = [var for var in all_vars if "generate_model" in var.name]
            dd_vars = [
                var for var in all_vars if "discriminate_model" in var.name
            ]
            ll_pp_vars = [
                var for var in all_vars if "learning_pooling_model" in var.name
            ]
            return gg_vars, dd_vars, ll_pp_vars

        # depend on update_ops here, otherwise batch norm breaks!
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.device(opt.gpu_num):
            if opt.model_type == 1:  # deconvolution
                pre_loss, mse, pred = model[0].calculate_loss(inputs, labels)
                g_vars, _, _ = get_vars()
                with tf.control_dependencies(update_ops):
                    train_ops = optimizer.minimize(pre_loss, var_list=g_vars)
            elif opt.model_type == 2:  # deconvolution + learnable pooling
                _, mse, pred = model[0].calculate_loss(inputs, labels)
                l_p_loss = model[1].calculate_loss(pred, labels,
                                                   opt.model_2_scale)
                g_vars, _, l_p_vars = get_vars()
                with tf.control_dependencies(update_ops):
                    train_ops = optimizer.minimize(l_p_loss,
                                                   var_list=g_vars + l_p_vars)
            elif opt.model_type == 3:  # deconvolution + GAN
                pre_loss, mse, pred = model[0].calculate_loss(inputs, labels)
                gen_loss, dis_loss = model[1].calculate_loss(pred, labels)
                g_vars, d_vars, _ = get_vars()
                with tf.control_dependencies(update_ops):
                    # train D --> train G --> train the prior network (i.e. G again)
                    d_train_ops = optimizer.minimize(dis_loss, var_list=d_vars)
                    g_train_ops = optimizer.minimize(gen_loss, var_list=g_vars)
                    pre_train_ops = optimizer.minimize(pre_loss,
                                                       var_list=g_vars)

        tf.summary.scalar("MSE", mse)

        tf.add_to_collection("input_batch", inputs)
        tf.add_to_collection("predictions", pred)

        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        # start training
        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = opt.per_process_gpu_memory_fraction
        with tf.Session(config=config) as sess:
            # initialize the variables first
            sess.run(init)

            if opt.model_type == 1:
                model_type = "model_1"
            elif opt.model_type == 2:
                model_type = "model_2_" + str(opt.model_2_layers)
            elif opt.model_type == 3:
                model_type = "model_3"
            summary_path = opt.summary_path + model_type + "\\data_SNR_" + str(
                opt.SNR)
            writer = tf.summary.FileWriter(summary_path, sess.graph)
            merge_ops = tf.summary.merge_all()

            start = time.time()

            data_path = opt.train_data_path + "data_SNR_" + str(opt.SNR)
            # training dataset
            train_dataset = CSISet(data_path,
                                   opt.batch_size,
                                   True,
                                   state="train")
            # validation dataset
            validation_dataset = CSISet(data_path,
                                        opt.batch_size,
                                        True,
                                        state="validation")

            # keep intermediate train/validation values for later plotting
            train_mse_for_plot = []
            valid_mse_for_plot = []

            for num in range(opt.num_epoch):
                # check whether the learning rate should change
                if opt.optimizer_type == "Momentum" and (
                        num % opt.learning_rate_change_epoch) == 0:
                    learning_rate *= opt.learning_rate_decay
                    print("第%i个epoch开始,当前学习率是%f" % (num, learning_rate))

                for ii, (batch_x,
                         batch_y) in enumerate(train_dataset.get_data()):
                    if opt.model_type == 1 or opt.model_type == 2:
                        _, train_mse, summary = sess.run(
                            [train_ops, mse, merge_ops],
                            feed_dict={
                                inputs: batch_x,
                                labels: batch_y
                            })
                    elif opt.model_type == 3:
                        _, _, _, train_mse, summary = sess.run(
                            [d_train_ops, g_train_ops, pre_train_ops, mse,
                             merge_ops],
                            feed_dict={inputs: batch_x, labels: batch_y})
                    writer.add_summary(summary)

                    if (ii + 1) % 1000 == 0:
                        print("epoch-%d, batch_num-%d: 当前batch训练数据误差是%f" %
                              (num + 1, ii + 1, train_mse))

                        # evaluate on the validation set every 1000 batches
                        validate_mse = 0
                        jj = 0
                        for (validate_x,
                             validate_y) in validation_dataset.get_data():
                            temp_mse = sess.run(mse,
                                                feed_dict={
                                                    inputs: validate_x,
                                                    labels: validate_y
                                                })
                            validate_mse += temp_mse
                            jj += 1
                        validate_mse = validate_mse / jj
                        print("epoch-%d: 当前阶段验证集数据平均误差是%f" %
                              (num + 1, validate_mse))
                        train_mse_for_plot.append(train_mse)
                        valid_mse_for_plot.append(validate_mse)

            end = time.time()

            utils.print_time(start, end, "running " + str(opt.num_epoch) + " epochs")

            plot_path = opt.result_path + model_type + "\\data_SNR_" + str(
                opt.SNR) + "\\train"
            utils.plot_fig(train_mse_for_plot, valid_mse_for_plot, plot_path)
            print("训练过程中最小验证误差是%f" % min(valid_mse_for_plot))

            # save the model file
            model_file = opt.model_path + model_type + "\\data_SNR_" + str(
                opt.SNR) + "\\data_SNR_" + str(opt.SNR)
            model_utils.save_model(saver, sess, model_file)
Example #11
def train(**kwargs):
    # update the configuration from command-line arguments
    opt = DefaultConfig()
    opt.parse(kwargs)
    print("Configuration complete")

    # step 1: model
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_num
    # TODO: the hard-coded 1024 may need to come from a variable later
    model = getattr(models, opt.model)(opt, 1024)
    if opt.model == "StatisticModel" and model.PCA_state is False:
        model.pretrainpca()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    if opt.use_gpu:
        model.cuda()
    print("模型加载完成")

    # step 2: data
    train_data = VideoSet(opt, state='train')
    train_dataloader = DataLoader(train_data,
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers)
    print("数据集准备就绪")

    # step 3: loss function and optimizer
    criterion = torch.nn.BCELoss(size_average=False)
    lr = opt.lr
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=opt.weight_decay)

    # step 4: metrics

    print("Start training")
    # training loop
    start = time.time()
    for epoch in range(opt.max_epoch):
        epoch_start = time.time()
        t1 = time.time()
        for ii, (data, label) in enumerate(train_dataloader):
            # train the model parameters
            input_data = Variable(data)
            target = Variable(label)
            # ipdb.set_trace()
            if opt.use_gpu:
                input_data = input_data.cuda()
                target = target.cuda()
            optimizer.zero_grad()
            score = model(input_data)
            loss = criterion(score, target)
            loss.backward()
            optimizer.step()
            if (ii + 1) % 100 == 0:
                t2 = time.time()
                print('-------------------------------------------------')
                print('epoch %d, batch %d: loss is %.4f' %
                      (epoch + 1, ii + 1, loss.item()))
                print_time(t1, t2, 'training this batch')
                print('-------------------------------------------------')
                t1 = time.time()
        epoch_end = time.time()
        print_time(epoch_start, epoch_end, 'epoch %d' % (epoch + 1))
    end = time.time()
    print_time(start, end)

    # save the model parameters one last time
    model.save()
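One implicit contract in the snippet above: BCELoss expects model outputs that are already probabilities in [0, 1] (e.g. after a sigmoid), and size_average=False corresponds to reduction='sum' in current PyTorch; a minimal check:

import torch

criterion = torch.nn.BCELoss(reduction='sum')    # modern spelling of size_average=False
score = torch.sigmoid(torch.randn(4, 1))         # probabilities in [0, 1]
target = torch.randint(0, 2, (4, 1)).float()     # binary labels
print(criterion(score, target).item())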
Example #12
def run_batch(root, start, end, idx=None):
    """
    Run this script in batch mode. Download reading data whose timestamps
    lie within start and end dates.

    Dates must use the following 24-hour format (note the T between date
    and time):
        YYYY-MM-DDTHH:MM:SS

    If idx is a non-negative integer, instead download the meter at that index.
    idx is zero-indexed. If idx is greater than the number of meters, nothing
    happens; no files are downloaded. Default behavior is to download data for
    all meters.

    Params:
        root string
        start string
        end string
        idx integer
    """
    s_date = get_date(start)
    e_date = get_date(end)

    if not s_date or not e_date:
        raise ValueError('Invalid/missing dates')
    elif s_date > e_date:
        raise ValueError('Start date must come before end date')
    elif not utils.exists_dir(root):
        raise ValueError('Root directory not found')
    elif idx is not None and not is_valid_index(idx):
        raise ValueError('Index must be non-negative integer')

    creds_file = defaults.creds(root)
    cnxn_str = utils.get_cnxn_str(creds_file)
    output_dir = defaults.downloads(root)
    meter_file = defaults.meter_file(root)

    utils.print_time('GETTER START')

    with Cursor.Cursor(cnxn_str) as cursor:
        dq = get_reading_from_name_query_str()
        meters = utils.read_meter_file(meter_file)
        for i, m in enumerate(meters):
            if idx is not None and idx != i:
                continue
            ion_name = utils.get_ion_name(m)
            qid = utils.get_ion_qid(m)
            try:
                cursor.execute(dq, ion_name, qid, str(s_date), str(e_date))
            except pyodbc.Error:
                utils.error(
                    'Problem with query to get data for meter %s qid %d' %
                    (ion_name, qid))
                continue
            if cursor.rowcount == 0:
                utils.warn('No data found for meter %s qid %d' %
                           (ion_name, qid))
                continue

            meterId, meterName = utils.get_lucid_id_and_name(m)
            s_date_str = utils.make_lucid_ts(str(s_date))
            e_date_str = utils.make_lucid_ts(str(e_date))
            dl_fname = "%sT%sT%s.csv" % (meterId, s_date_str, e_date_str)
            path = os.path.join(output_dir, dl_fname)

            print('Writing data for meter %s qid %d to file: %s ...' %
                  (ion_name, qid, path), end='')
            with open(path, 'w', newline='') as data_file:
                writer = csv.writer(data_file)
                writer.writerow([meterId, meterName])

                for row in cursor:
                    ts = row.TimestampUTC
                    val = row.Value
                    data_row = [utils.make_lucid_ts(ts), val]
                    writer.writerow(data_row)
                print('done')
    utils.print_time('GETTER END')
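get_date itself is not shown here; below is a sketch consistent with the docstring's format and with the falsy check above (the 'not s_date' test implies it returns None when parsing fails), offered as an assumption rather than the project's actual helper:

from datetime import datetime

def get_date(s):
    """Parse YYYY-MM-DDTHH:MM:SS; return None on failure (hypothetical helper)."""
    try:
        return datetime.strptime(s, '%Y-%m-%dT%H:%M:%S')
    except (TypeError, ValueError):
        return None

print(get_date('2020-01-31T23:59:59'))  # 2020-01-31 23:59:59
print(get_date('not-a-date'))           # None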
Example #13
def inference(**kwargs):
    # update the configuration from command-line arguments
    opt = DefaultConfig()
    opt.parse(kwargs)
    print("Configuration complete")

    if opt.model_type == 1:
        model_type = "model_1"
    elif opt.model_type == 2:
        model_type = "model_2_" + opt.model_2_layers
    elif opt.model_type == 3:
        model_type = "model_3"

    # load the static graph
    model_files = opt.model_path + model_type + "\\data_SNR_" + str(opt.SNR)
    saver = tf.train.import_meta_graph(model_files + "\\data_SNR_" +
                                       str(opt.SNR) + ".meta")

    # start testing
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = opt.per_process_gpu_memory_fraction
    with tf.Session(config=config) as sess:
        # restore the parameter values
        saver.restore(sess, tf.train.latest_checkpoint(model_files))

        # test dataset
        test_dataset = CSISet(opt.test_data_path, opt.batch_size, False,
                              "test")

        data_loss = []  # per-batch error between the true and predicted transmitted signals

        print("Prediction started!")
        start_time = time.time()

        inputs = tf.get_collection("input_batch")[0]
        predictions = tf.get_collection("predictions")[0]

        for ii, (batch_x, batch_tx,
                 batch_rx) in enumerate(test_dataset.get_data()):

            # pred_H is the full channel response predicted by the model: [batch, 72, 14, 2]
            # applying H = rx/tx to (pred_H, batch_rx) then recovers pred_batch_tx
            pred_H = np.squeeze(np.array(
                sess.run([predictions], feed_dict={inputs: batch_x})),
                                axis=0)
            complex_pred_H = pred_H[:, :, :, 0] + pred_H[:, :, :, 1] * 1j
            # ipdb.set_trace()
            pred_batch_tx = np.divide(batch_rx, complex_pred_H)

            pred_batch_tx[:, :5, 5:7] = 1.
            pred_batch_tx[:, 67:72, 5:7] = 1.
            batch_tx[:, :5, 5:7] = 1.
            batch_tx[:, 67:72, 5:7] = 1.

            batch_data_loss_ratio = np.mean(
                np.divide(abs(pred_batch_tx - batch_tx), abs(batch_tx)))
            # print(batch_data_loss)
            print("第%d个batch的发送信息预测平均误差是%.6f" %
                  (ii + 1, batch_data_loss_ratio))
            data_loss.append(batch_data_loss_ratio)

        result = np.mean(data_loss)
        print("信噪比为%d时模型在测试集上的平均估计误差为%.2f" % (opt.SNR, result))

        end_time = time.time()
        print_time(start_time, end_time, "entire test process")

        result_path = opt.result_path + model_type + "\\data_SNR_" + str(
            opt.SNR) + "\\test\\result.npy"
        np.save(result_path, data_loss)
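The heart of the loop above is recovering the transmitted signal by elementwise complex division; here is that step isolated with toy shapes and random stand-in data:

import numpy as np

pred_H = np.random.randn(2, 72, 14, 2)              # model output: [batch, 72, 14, 2] real/imag
complex_H = pred_H[..., 0] + 1j * pred_H[..., 1]    # -> complex channel [batch, 72, 14]
rx = np.random.randn(2, 72, 14) + 1j * np.random.randn(2, 72, 14)
tx_est = rx / complex_H                             # invert H = rx / tx
print(tx_est.shape)                                 # (2, 72, 14)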