Code example #1
File: train.py  Project: yrpang/mindspore
    else:
        rank = 0
        device_num = 1
    mindrecord_file = args.dataset_path
    if not os.path.exists(mindrecord_file):
        print("dataset file {} not exists, please check!".format(
            mindrecord_file))
        raise ValueError(mindrecord_file)
    dataset = create_gru_dataset(epoch_count=config.num_epochs,
                                 batch_size=config.batch_size,
                                 dataset_path=mindrecord_file,
                                 rank_size=device_num,
                                 rank_id=rank)
    dataset_size = dataset.get_dataset_size()
    print("dataset size is {}".format(dataset_size))
    network = Seq2Seq(config)
    network = GRUWithLossCell(network)
    lr = dynamic_lr(config, dataset_size)
    opt = Adam(network.trainable_params(), learning_rate=lr)
    scale_manager = DynamicLossScaleManager(
        init_loss_scale=config.init_loss_scale_value,
        scale_factor=config.scale_factor,
        scale_window=config.scale_window)
    update_cell = scale_manager.get_update_cell()
    netwithgrads = GRUTrainOneStepWithLossScaleCell(network, opt, update_cell)

    time_cb = TimeMonitor(data_size=dataset_size)
    loss_cb = LossCallBack(rank_id=rank)
    cb = [time_cb, loss_cb]
    # Save checkpoint
    if config.save_checkpoint:
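
The snippet is cut off inside the save_checkpoint branch. A hedged sketch of how this branch commonly continues, modeled on the CheckpointConfig/ModelCheckpoint pattern in code example #4 below; the config field names and checkpoint directory are assumptions:

        # Assumed continuation of the save_checkpoint branch; the field
        # names are modeled on code example #4, not taken from this file.
        config_ck = CheckpointConfig(
            save_checkpoint_steps=config.save_checkpoint_steps,
            keep_checkpoint_max=config.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix="gru",
                                  directory=config.ckpt_path,
                                  config=config_ck)
        cb += [ckpt_cb]
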
Code example #2
parser.add_argument("--file_format",
                    type=str,
                    choices=["AIR", "MINDIR"],
                    default="MINDIR",
                    help="file format.")
parser.add_argument('--ckpt_file',
                    type=str,
                    required=True,
                    help='ckpt file path')
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE,
                    device_target=args.device_target,
                    reserve_class_name_in_scope=False,
                    device_id=args.device_id,
                    save_graphs=False)

if __name__ == "__main__":
    network = Seq2Seq(config, is_training=False)
    network = GRUInferCell(network)
    network.set_train(False)
    if args.ckpt_file != "":
        parameter_dict = load_checkpoint(args.ckpt_file)
        load_param_into_net(network, parameter_dict)

    source_ids = Tensor(
        np.random.uniform(0.0,
                          1e5,
                          size=[config.eval_batch_size,
                                config.max_length]).astype(np.int32))
    target_ids = Tensor(
        np.random.uniform(0.0,
                          1e5,
                          size=[config.eval_batch_size,
                                config.max_length]).astype(np.int32))
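
The snippet stops before the actual export call. A hedged sketch of the step that usually follows in a MindSpore export script, assuming mindspore.export is imported in the truncated header; the "gru" file name is illustrative, not taken from the original:

    # Assumed final step: serialize the inference network together with
    # example inputs in the requested format (AIR or MINDIR).
    export(network, source_ids, target_ids,
           file_name="gru", file_format=args.file_format)
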
Code example #3
File: eval.py  Project: mindspore-ai/course
                        default='',
                        help='checkpoint path.')
    args = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE,
                        save_graphs=False,
                        device_target='Ascend')

    rank = 0
    device_num = 1
    ds_eval = create_dataset(args.dataset_path,
                             cfg.eval_batch_size,
                             is_training=False)

    network = Seq2Seq(cfg, is_train=False)
    network = InferCell(network, cfg)
    network.set_train(False)
    parameter_dict = load_checkpoint(args.checkpoint_path)
    load_param_into_net(network, parameter_dict)
    model = Model(network)

    with open(os.path.join(args.dataset_path, "en_vocab.txt"),
              'r',
              encoding='utf-8') as f:
        data = f.read()
    en_vocab = list(data.split('\n'))

    with open(os.path.join(args.dataset_path, "ch_vocab.txt"),
              'r',
              encoding='utf-8') as f:
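
With both vocabularies loaded, predicted id sequences can be mapped back to text. A hedged sketch of such a decoding helper; the end-of-sentence id and the lookup logic are assumptions, not taken from the original eval.py:

    # Hypothetical helper: turn a sequence of ids into a sentence using a
    # vocabulary list. eos_id=1 is an assumed end-of-sentence marker.
    def ids_to_sentence(ids, vocab, eos_id=1):
        tokens = []
        for idx in ids:
            if idx == eos_id:
                break
            if 0 <= idx < len(vocab):
                tokens.append(vocab[idx])
        return ' '.join(tokens)
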
Code example #4
File: train.py  Project: mindspore-ai/course
                        type=str,
                        default='./preprocess',
                        help='dataset path.')
    parser.add_argument('--ckpt_save_path',
                        type=str,
                        default='./',
                        help='checkpoint save path.')
    args = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE,
                        save_graphs=False,
                        device_target='Ascend')

    ds_train = create_dataset(args.dataset_path, cfg.batch_size)

    network = Seq2Seq(cfg)
    network = WithLossCell(network, cfg)
    optimizer = nn.Adam(network.trainable_params(),
                        learning_rate=cfg.learning_rate,
                        beta1=0.9,
                        beta2=0.98)
    model = Model(network, optimizer=optimizer)

    loss_cb = LossMonitor()
    config_ck = CheckpointConfig(
        save_checkpoint_steps=cfg.save_checkpoint_steps,
        keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="gru",
                                 directory=args.ckpt_save_path,
                                 config=config_ck)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
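
The snippet ends just after the callbacks are created. A minimal sketch of the training launch that typically follows; the epoch-count field cfg.num_epochs is an assumed name:

    # Assumed continuation: run training with the callbacks defined above.
    model.train(cfg.num_epochs,
                ds_train,
                callbacks=[time_cb, ckpoint_cb, loss_cb],
                dataset_sink_mode=True)
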
Code example #5
def run_gru_eval():
    """
    GRU evaluation.
    """
    parser = argparse.ArgumentParser(description='GRU eval')
    parser.add_argument(
        "--device_target",
        type=str,
        default="Ascend",
        help="device where the code will be implemented, default is Ascend")
    parser.add_argument('--device_id',
                        type=int,
                        default=0,
                        help='device id of GPU or Ascend, default is 0')
    parser.add_argument('--device_num',
                        type=int,
                        default=1,
                        help='number of devices to use, default is 1')
    parser.add_argument('--ckpt_file',
                        type=str,
                        default="",
                        help='ckpt file path')
    parser.add_argument("--dataset_path",
                        type=str,
                        default="",
                        help="Dataset path, default: f`sns.")
    args = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args.device_target,
                        reserve_class_name_in_scope=False,
                        device_id=args.device_id,
                        save_graphs=False)
    dataset = create_gru_dataset(epoch_count=config.num_epochs,
                                 batch_size=config.eval_batch_size,
                                 dataset_path=args.dataset_path,
                                 rank_size=args.device_num,
                                 rank_id=0,
                                 do_shuffle=False,
                                 is_training=False)
    dataset_size = dataset.get_dataset_size()
    print("dataset size is {}".format(dataset_size))
    network = Seq2Seq(config, is_training=False)
    network = GRUInferCell(network)
    network.set_train(False)
    if args.ckpt_file != "":
        parameter_dict = load_checkpoint(args.ckpt_file)
        load_param_into_net(network, parameter_dict)
    model = Model(network)

    predictions = []
    source_sents = []
    target_sents = []
    eval_text_len = 0
    for batch in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
        source_sents.append(batch["source_ids"])
        target_sents.append(batch["target_ids"])
        source_ids = Tensor(batch["source_ids"], mstype.int32)
        target_ids = Tensor(batch["target_ids"], mstype.int32)
        predicted_ids = model.predict(source_ids, target_ids)
        print("predicts is ", predicted_ids.asnumpy())
        print("target_ids is ", target_ids)
        predictions.append(predicted_ids.asnumpy())
        eval_text_len = eval_text_len + 1

    f_output = open(config.output_file, 'w')
    f_target = open(config.target_file, 'w')
    for batch_out, true_sentence in zip(predictions, target_sents):
        for i in range(config.eval_batch_size):
            target_ids = [str(x) for x in true_sentence[i].tolist()]
            f_target.write(" ".join(target_ids) + "\n")
            token_ids = [str(x) for x in batch_out[i].tolist()]
            f_output.write(" ".join(token_ids) + "\n")
    f_output.close()
    f_target.close()
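
The two files written above hold space-separated token ids, one sentence per line. A hedged sketch of one way to score them with corpus-level BLEU over the raw id strings, using NLTK; the original project may rely on its own scoring script instead:

from nltk.translate.bleu_score import corpus_bleu

def score_id_files(output_file, target_file):
    # Each line is one sentence of space-separated token ids.
    with open(output_file, encoding='utf-8') as f:
        hypotheses = [line.split() for line in f]
    with open(target_file, encoding='utf-8') as f:
        references = [[line.split()] for line in f]  # one reference each
    return corpus_bleu(references, hypotheses)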