示例#1
0
        # Imported inside this nested block rather than at module level, so
        # torch.utils.tensorboard is only loaded when this branch executes.
        # NOTE(review): the enclosing condition is outside this excerpt.
        from torch.utils.tensorboard import SummaryWriter
        # Run name = <tensorboard_dir> + "s2s-e<n_enc>d<n_dec>h<n_head>".
        # No path separator is inserted between args.tensorboard_dir and the
        # "s2s-..." suffix; the directory string is used as a raw prefix.
        # (n_enc / n_dec / n_head are presumably encoder-layer, decoder-layer
        # and attention-head counts -- confirm against the argument parser.)
        summary_name = '{}s2s-e{}d{}h{}'.format(args.tensorboard_dir,
                                                args.n_enc, args.n_dec,
                                                args.n_head)
        # Tag the run name when the use_cnn option is set.
        if args.use_cnn:
            summary_name += 'cnn'
        # SummaryWriter receives the composed name as its first argument.
        tensorboard_writer = SummaryWriter(summary_name)

    # Gather the training options from the parsed arguments into a single
    # dict; it is handed to train_model() below as the `cfg` positional
    # argument. Keys mirror the attribute names on `args` one-to-one.
    cfg = {
        'model_path': args.model_path,
        'lr': args.lr,
        'label_smooth': args.label_smooth,
        'weight_decay': args.weight_decay,
        'teacher_force': args.teacher_force,
        'n_warmup': args.n_warmup,
        'n_const': args.n_const,
        'b_input': args.b_input,
        'b_update': args.b_update,
        'n_print': args.n_print
    }
    # Pair of readers passed together as one argument
    # (presumably training and validation readers -- confirm at their
    # construction site, which is outside this excerpt).
    datasets = (tr_reader, cv_reader)
    # Launch training. Positional order: model, datasets, epoch count,
    # device, cfg; the remaining options are forwarded by keyword.
    # NOTE(review): tensorboard_writer is referenced unconditionally here,
    # while its only visible assignment is inside a nested block -- confirm
    # it is given a default earlier in the function.
    train_model(model,
                datasets,
                args.n_epoch,
                device,
                cfg,
                loss_norm=args.loss_norm,
                grad_norm=args.grad_norm,
                fp16=args.fp16,
                tensorboard_writer=tensorboard_writer)
示例#2
0
                          time_win=args.time_win)
    # Reader built from the validation inputs (args.valid_scp and
    # args.valid_target); the remaining options are forwarded from the
    # command-line arguments by keyword.
    # NOTE(review): ScpStreamReader is defined elsewhere; the meaning of
    # each option is not visible from this excerpt.
    cv_reader = ScpStreamReader(args.valid_scp,
                                args.valid_target,
                                downsample=args.downsample,
                                mean_sub=args.mean_sub,
                                max_len=args.max_len,
                                max_utt=args.max_utt,
                                fp16=args.fp16)

    # Gather the training options from the parsed arguments into a single
    # dict; it is handed to train_model() below as the `cfg` positional
    # argument. Keys mirror the attribute names on `args` one-to-one.
    cfg = {
        'model_path': args.model_path,
        'lr': args.lr,
        'label_smooth': args.label_smooth,
        'weight_decay': args.weight_decay,
        'teacher_force': args.teacher_force,
        'n_warmup': args.n_warmup,
        'n_const': args.n_const,
        'b_input': args.b_input,
        'b_update': args.b_update,
        'n_print': args.n_print
    }
    # Pair of readers passed together as one argument; cv_reader is the
    # reader created from the validation inputs
    # (tr_reader is presumably the training reader -- its construction
    # starts outside this excerpt).
    datasets = (tr_reader, cv_reader)
    # Launch training. Positional order: model, datasets, epoch count,
    # device, cfg; loss_norm, grad_norm and fp16 are forwarded by keyword.
    train_model(model,
                datasets,
                args.n_epoch,
                device,
                cfg,
                loss_norm=args.loss_norm,
                grad_norm=args.grad_norm,
                fp16=args.fp16)