Beispiel #1
0
        pipe_command = 'python ctr_dataset_reader.py'

        dataset.init(batch_size=batch_size,
                     use_var=self.feeds,
                     pipe_command=pipe_command,
                     thread_num=thread_num)

        dataset.set_filelist(filelist)

        for epoch_id in range(1):
            pass_start = time.time()
            dataset.set_filelist(filelist)
            exe.train_from_dataset(program=fluid.default_main_program(),
                                   dataset=dataset,
                                   fetch_list=[self.avg_cost],
                                   fetch_info=["cost"],
                                   print_period=2,
                                   debug=int(os.getenv("Debug", "0")))
            pass_time = time.time() - pass_start

        if os.getenv("SAVE_MODEL") == "1":
            model_dir = tempfile.mkdtemp()
            fleet.save_inference_model(exe, model_dir,
                                       [feed.name for feed in self.feeds],
                                       self.avg_cost)
            self.check_model_right(model_dir)
            shutil.rmtree(model_dir)

# Script entry point: only when this file is executed directly (not imported),
# pass TestDistCTR2x2 to runtime_main.
# NOTE(review): runtime_main is defined outside this view; presumably it
# instantiates the given test class and drives the run -- confirm.
if __name__ == "__main__":
    runtime_main(TestDistCTR2x2)
        fleet.init_worker()
        exe.run(fluid.default_startup_program())
        batch_size = 4
        # reader
        train_reader = paddle.batch(fake_simnet_reader(),
                                    batch_size=batch_size)
        self.reader.decorate_sample_list_generator(train_reader)
        for epoch_id in range(1):
            self.reader.start()
            try:
                pass_start = time.time()
                while True:
                    loss_val = exe.run(program=fluid.default_main_program(),
                                       fetch_list=[self.avg_cost.name])
                    loss_val = np.mean(loss_val)
                    message = "TRAIN ---> pass: {} loss: {}\n".format(
                        epoch_id, loss_val)
                    fleet_util.print_on_rank(message, 0)

                pass_time = time.time() - pass_start
            except fluid.core.EOFException:
                self.reader.reset()
        fleet.stop_worker()

    def do_dataset_training(self, fleet):
        """Do nothing; this method performs no dataset-based training.

        Args:
            fleet: Accepted but never used by this implementation.

        Returns:
            None.
        """
        # NOTE(review): presumably a deliberate placeholder so the method
        # exists on this class without doing any work -- confirm.
        return None


# Script entry point: only when this file is executed directly (not imported),
# pass TestDistSimnetBow2x2 to runtime_main.
# NOTE(review): runtime_main is defined outside this view; presumably it
# instantiates the given test class and drives the run -- confirm.
if __name__ == "__main__":
    runtime_main(TestDistSimnetBow2x2)
Beispiel #3
0
        dataset.set_filelist(filelist)
        dataset._set_thread(thread_num)

        for epoch_id in range(1):
            pass_start = time.time()
            dataset.set_filelist(filelist)
            exe.train_from_dataset(
                program=fleet.main_program,
                dataset=dataset,
                fetch_list=[self.avg_cost],
                fetch_info=["cost"],
                print_period=2,
                debug=int(os.getenv("Debug", "0")))
            pass_time = time.time() - pass_start

        if os.getenv("SAVE_MODEL") == "1":
            model_dir = tempfile.mkdtemp()
            fleet.save_inference_model(exe, model_dir,
                                       [feed.name for feed in self.feeds],
                                       self.avg_cost)
            self.check_model_right(model_dir)
            if fleet.is_first_worker():
                fleet.save_persistables(executor=exe, dirname=model_dir)
            shutil.rmtree(model_dir)

        fleet.stop_worker()


# Script entry point: only when this file is executed directly (not imported),
# pass TestDistGpuPsCTR2x2 to runtime_main.
# NOTE(review): runtime_main is defined outside this view; presumably it
# instantiates the given test class and drives the run -- confirm.
if __name__ == "__main__":
    runtime_main(TestDistGpuPsCTR2x2)