Example #1
0
    )
    # Build the loss network from the bare model and the hyper-parameters.
    loss_net = NetWithLossClass(model, hparams)
    # Learning-rate values come from get_lr(base lr, number of epochs, steps
    # per epoch) and are converted to a Tensor for the optimizer.
    # presumably a per-step schedule — TODO confirm against get_lr.
    lr = get_lr(hparams.optimizer_params["lr"], hparams.nepochs,
                step_size_per_epoch)
    lr = Tensor(lr)

    # Optionally initialise the model from a saved checkpoint.
    # NOTE(review): the condition tests args.checkpoint, but the file that is
    # loaded is args.pre_trained_model_path — confirm both arguments exist on
    # the parser and that using two different options here is intended.
    if args.checkpoint != '':
        param_dict = load_checkpoint(args.pre_trained_model_path)
        load_param_into_net(model, param_dict)
        print('Successfully loading the pre-trained model')

    # Collect the trainable parameters now: `model` still refers to the
    # network here and is rebound to the Model wrapper further down.
    weights = model.trainable_params()
    # NOTE(review): loss_scale=1024. is set on the optimizer only; no matching
    # scale is passed to TrainOneStepCell below — verify this is intended.
    optimizer = Adam(weights, learning_rate=lr, loss_scale=1024.)
    train_net = TrainOneStepCell(loss_net, optimizer)

    # From here on `model` is the Model wrapper around the training cell,
    # no longer the network itself.
    model = Model(train_net)
    # Monitor callback built from the learning-rate tensor.
    lr_cb = Monitor(lr)
    callback_list = [lr_cb]
    # In distributed runs each rank writes to its own 'ckpt_<rank>/'
    # sub-directory of checkpoint_dir; otherwise checkpoint_dir is used as is.
    if args.is_distributed:
        ckpt_path = os.path.join(args.checkpoint_dir,
                                 'ckpt_' + str(get_rank()) + '/')
    else:
        ckpt_path = args.checkpoint_dir
    # Checkpoint every step_size_per_epoch steps (presumably once per epoch)
    # and keep at most 10 checkpoint files.
    config_ck = CheckpointConfig(save_checkpoint_steps=step_size_per_epoch,
                                 keep_checkpoint_max=10)
    ckpt_cb = ModelCheckpoint(prefix='wavenet',
                              directory=ckpt_path,
                              config=config_ck)
    callback_list.append(ckpt_cb)
    # Run training for hparams.nepochs epochs with the LR monitor and the
    # checkpoint callback attached.
    model.train(hparams.nepochs, data_loaders, callbacks=callback_list)
Example #2
0
                step_size_per_epoch)
    # Convert the learning-rate values to a Tensor for the optimizer.
    lr = Tensor(lr)

    # Optionally initialise the model from a saved checkpoint.
    # NOTE(review): the condition tests args.checkpoint, but the file that is
    # loaded is args.pre_trained_model_path — confirm both arguments exist on
    # the parser and that using two different options here is intended.
    if args.checkpoint != '':
        param_dict = load_checkpoint(args.pre_trained_model_path)
        load_param_into_net(model, param_dict)
        print('Successfully loading the pre-trained model')

    # Collect the trainable parameters now: `model` still refers to the
    # network here and is rebound to the Model wrapper further down.
    weights = model.trainable_params()
    # NOTE(review): loss_scale=1024. is set on the optimizer only; no matching
    # scale is passed to TrainOneStepCell below — verify this is intended.
    optimizer = Adam(weights, learning_rate=lr, loss_scale=1024.)
    train_net = TrainOneStepCell(loss_net, optimizer)

    # From here on `model` is the Model wrapper around the training cell,
    # no longer the network itself.
    model = Model(train_net)
    # Monitor callback built from the learning-rate tensor.
    lr_cb = Monitor(lr)
    callback_list = [lr_cb]
    # In distributed runs each rank writes to its own 'ckpt_<rank>/'
    # sub-directory of checkpoint_dir; otherwise checkpoint_dir is used as is.
    if args.is_distributed:
        ckpt_path = os.path.join(args.checkpoint_dir,
                                 'ckpt_' + str(get_rank()) + '/')
    else:
        ckpt_path = args.checkpoint_dir
    # Checkpoint every step_size_per_epoch steps (presumably once per epoch)
    # and keep at most 10 checkpoint files.
    config_ck = CheckpointConfig(save_checkpoint_steps=step_size_per_epoch,
                                 keep_checkpoint_max=10)
    ckpt_cb = ModelCheckpoint(prefix='wavenet',
                              directory=ckpt_path,
                              config=config_ck)
    callback_list.append(ckpt_cb)
    # Run training for hparams.nepochs epochs with both callbacks attached;
    # dataset sink mode is explicitly turned off for this run.
    model.train(hparams.nepochs,
                data_loaders,
                callbacks=callback_list,
                dataset_sink_mode=False)