Beispiel #1
0
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value -- including ``"False"`` and ``"0"`` -- parses as
    ``True``. This converter interprets the usual spellings instead.

    Args:
        value: the raw option value (a string, or already a bool).

    Returns:
        bool: the parsed flag. The empty string maps to ``False``, exactly as
        it did under ``type=bool``.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "on", "1"):
        return True
    if token in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value))


def main():
    """Command-line entry point: train the LSTM acoustic model.

    Reads the training and data YAML files, builds the dataset, chunk
    dataloader, model, Adam optimizer and cross-entropy criterion, optionally
    resumes from a checkpoint, then calls ``run_train_epoch`` once per epoch
    and writes ``model.<epoch>.tar`` into ``-exp_dir`` after each one.

    With ``-hvd true`` the run is distributed with Horovod and only rank 0
    writes checkpoints.

    Raises:
        FileNotFoundError: if ``-resume_from_model`` is given but is not a file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-exp_dir")
    parser.add_argument("-dataPath",
                        default='',
                        type=str,
                        help="path of data files")
    parser.add_argument("-train_config")
    parser.add_argument("-data_config")
    parser.add_argument("-lr",
                        default=0.0001,
                        type=float,
                        help="Override the LR in the config")
    parser.add_argument("-batch_size",
                        default=32,
                        type=int,
                        help="Override the batch size in the config")
    parser.add_argument("-data_loader_threads",
                        default=0,
                        type=int,
                        help="number of workers for data loading")
    parser.add_argument("-max_grad_norm",
                        default=5,
                        type=float,
                        help="max_grad_norm for gradient clipping")
    parser.add_argument("-sweep_size",
                        default=200,
                        type=float,
                        help="process n hours of data per sweep (default:200)")
    parser.add_argument("-num_epochs",
                        default=1,
                        type=int,
                        help="number of training epochs (default:1)")
    # type=bool would turn "-global_mvn False" into True; parse it properly.
    parser.add_argument("-global_mvn",
                        default=False,
                        type=_str2bool,
                        help="if apply global mean and variance normalization")
    parser.add_argument(
        "-resume_from_model",
        type=str,
        help="the model from which you want to resume training")
    # NOTE(review): -dropout is parsed but not applied to the model built
    # below (that uses model_config["dropout"]); args is forwarded to
    # run_train_epoch, so confirm whether it is consumed there.
    parser.add_argument("-dropout", type=float, help="set the dropout ratio")
    # The "aneal" spelling is part of the command-line interface; it is kept
    # so that existing launch scripts continue to work.
    parser.add_argument("-aneal_lr_epoch",
                        default=2,
                        type=int,
                        help="start to aneal the learning rate from this epoch"
                        )
    parser.add_argument("-aneal_lr_ratio",
                        default=0.5,
                        type=float,
                        help="the ratio to aneal the learning rate")
    parser.add_argument('-p',
                        '--print-freq',
                        default=100,
                        type=int,
                        metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('-hvd',
                        default=False,
                        type=_str2bool,
                        help="whether to use horovod for training")

    args = parser.parse_args()

    with open(args.train_config) as f:
        config = yaml.safe_load(f)

    config["sweep_size"] = args.sweep_size
    with open(args.data_config) as f:
        data = yaml.safe_load(f)
    config["source_paths"] = list(data['clean_source'].values())
    if 'dir_noise' in data:
        config["dir_noise_paths"] = list(data['dir_noise'].values())
    if 'rir' in data:
        config["rir_paths"] = list(data['rir'].values())

    config['data_path'] = args.dataPath

    print("Experiment starts with config {}".format(
        json.dumps(config, sort_keys=True, indent=4)))

    # Initialize Horovod
    if args.hvd:
        import horovod.torch as hvd
        hvd.init()
        th.cuda.set_device(hvd.local_rank())
        print("Run experiments with world size {}".format(hvd.size()))

    # exist_ok avoids a crash when several ranks create the directory at once.
    os.makedirs(args.exp_dir, exist_ok=True)

    trainset = SpeechDataset(config)
    # The parser defines no "-multi_gpu" option; distributed loading follows
    # the Horovod switch.
    train_dataloader = ChunkDataloader(trainset,
                                       batch_size=args.batch_size,
                                       distributed=args.hvd,
                                       num_workers=args.data_loader_threads)

    if args.global_mvn:
        transform = GlobalMeanVarianceNormalization()
        print("Estimating global mean and variance of feature vectors...")
        transform.learn_mean_and_variance_from_train_loader(
            trainset, trainset.stream_idx_for_transform, n_sample_to_use=2000)
        trainset.transform = transform
        print("Global mean and variance transform trained successfully!")

        with open(os.path.join(args.exp_dir, "transform.pkl"), 'wb') as f:
            pickle.dump(transform, f, pickle.HIGHEST_PROTOCOL)

    print("Data loader set up successfully!")
    print("Number of minibatches: {}".format(len(train_dataloader)))

    # create model
    model_config = config["model_config"]
    lstm = LSTMStack(model_config["feat_dim"], model_config["hidden_size"],
                     model_config["num_layers"], model_config["dropout"], True)
    model = NnetAM(lstm, model_config["hidden_size"] * 2,
                   model_config["label_size"])

    # Start training
    th.backends.cudnn.enabled = True
    if th.cuda.is_available():
        model.cuda()

    # optimizer
    optimizer = th.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)

    if args.hvd:
        # Broadcast parameters and optimizer state from rank 0 to all other
        # processes.
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        # Add Horovod Distributed Optimizer
        optimizer = hvd.DistributedOptimizer(
            optimizer, named_parameters=model.named_parameters())

    # criterion
    criterion = nn.CrossEntropyLoss(ignore_index=-100)

    start_epoch = 0
    if args.resume_from_model:
        # Explicit check instead of assert, which is stripped under "python -O".
        if not os.path.isfile(args.resume_from_model):
            raise FileNotFoundError(
                "ERROR: model file {} does not exist!".format(
                    args.resume_from_model))

        checkpoint = th.load(args.resume_from_model)
        # NOTE(review): checkpoints store the index of the epoch that just
        # finished, so resuming repeats that epoch -- confirm this is intended.
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' ".format(args.resume_from_model))

    model.train()
    for epoch in range(start_epoch, args.num_epochs):

        # anneal learning rate
        if epoch > args.aneal_lr_epoch:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= args.aneal_lr_ratio

        run_train_epoch(model, optimizer, criterion, train_dataloader, epoch,
                        args)

        # save model (rank 0 only when running under Horovod)
        if not args.hvd or hvd.rank() == 0:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
            }
            output_file = os.path.join(args.exp_dir,
                                       'model.' + str(epoch) + '.tar')
            th.save(checkpoint, output_file)
Beispiel #2
0
def _strip_module_prefix(state_dict):
    """Remove the ``module.`` prefix that DataParallel adds to parameter names.

    Each key is checked on its own, so an empty dict or a dict with only some
    prefixed keys is handled (the previous code looked at the last key only).

    Args:
        state_dict: mapping of parameter name to value.

    Returns:
        ``state_dict`` itself when no key carries the prefix, otherwise a new
        ``OrderedDict`` in the same order with the prefix removed.
    """
    from collections import OrderedDict

    prefix = "module."
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items())


def main():
    """Dump per-utterance model outputs (minus an optional log-prior) to an ark.

    This script is mainly used for LibriCSS evaluation: it dumps the
    log-likelihoods from the pretrained model released in the LibriCSS repo,
    https://github.com/chenzhuo1011/libri_css

    The model runs on ``cuda:<gpuid>`` when CUDA is available and on the CPU
    otherwise. For every utterance only the unpadded frames are written to
    ``ark:<out_file>``.

    Raises:
        FileNotFoundError: if ``-model_path`` is missing or is not a file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-config")
    parser.add_argument("-model_path")
    parser.add_argument("-data_path", default='', type=str, help="path of data files")
    parser.add_argument("-prior_path", default=None, help="the path to load the final.occs file")
    parser.add_argument("-transform", help="feature transformation matrix or mvn statistics")
    parser.add_argument("-out_file", help="write out the log-probs to this file")
    parser.add_argument("-batch_size", default=32, type=int, help="Override the batch size in the config")
    parser.add_argument("-sweep_size", default=200, type=float, help="process n hours of data per sweep (default:200)")
    parser.add_argument("-frame_subsampling_factor", default=1, type=int, help="the factor to subsample the features")
    # NOTE(review): -data_loader_threads is parsed but is not passed to
    # SeqDataloader below; left as is to keep the loading behaviour unchanged.
    parser.add_argument("-data_loader_threads", default=4, type=int, help="number of workers for data loading")
    parser.add_argument("-gpuid", default=0, type=int, help="GPU ID")

    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)

    config["sweep_size"] = args.sweep_size

    # A single "Eval" source whose wav entry is the -data_path argument.
    config["source_paths"] = [{"type": "Eval", "wav": args.data_path}]
    config["data_path"] = ""

    print("job starts with config {}".format(json.dumps(config, sort_keys=True, indent=4)))

    dataset = SpeechDataset(config)
    if args.transform is not None and os.path.isfile(args.transform):
        # pickle can execute arbitrary code while loading: only pass trusted
        # transform files.
        with open(args.transform, 'rb') as f:
            dataset.transform = pickle.load(f)

    test_dataloader = SeqDataloader(dataset,
                                    batch_size=args.batch_size,
                                    test_only=True)

    print("Data loader set up successfully!")
    print("Number of minibatches: {}".format(len(test_dataloader)))

    # create model
    model_config = config["model_config"]
    net = LSTMStack(model_config["feat_dim"], model_config["hidden_size"], model_config["num_layers"], model_config["dropout"], True)
    model = NnetAM(net, model_config["hidden_size"]*2, model_config["label_size"])

    # Route every placement through `device` so that the CPU fallback really
    # works and -gpuid is honoured when the checkpoint is loaded.
    use_cuda = th.cuda.is_available()
    device = th.device("cuda:{}".format(args.gpuid) if use_cuda else "cpu")
    if use_cuda:
        th.cuda.set_device(device)
    model.to(device)

    # Explicit check instead of assert, which is stripped under "python -O".
    if not args.model_path or not os.path.isfile(args.model_path):
        raise FileNotFoundError(
            "ERROR: model file {} does not exist!".format(args.model_path))

    checkpoint = th.load(args.model_path, map_location=device)
    model.load_state_dict(_strip_module_prefix(checkpoint['model']))
    print("=> loaded checkpoint '{}' ".format(args.model_path))

    log_prior = None
    if args.prior_path:
        prior = read_matrix(args.prior_path).numpy()
        # Normalise the first row to a distribution, then take the log.
        log_prior = th.tensor(np.log(prior[0]/np.sum(prior[0])), dtype=th.float)

    model.eval()
    with th.no_grad():
        with MatrixWriter("ark:"+args.out_file) as llout:
            for batch_idx, data in enumerate(test_dataloader):
                feat = data["x"]
                num_frs = data["num_frs"]
                utt_ids = data["utt_ids"]

                x = feat.to(th.float32)
                if args.frame_subsampling_factor > 1:
                    # Keep every frame_subsampling_factor-th step along dim 1
                    # and scale the frame counts accordingly.
                    x = x.unfold(1, 1, args.frame_subsampling_factor).squeeze(-1)
                    num_frs = [int(n/args.frame_subsampling_factor) for n in num_frs]

                x = x.to(device)
                prediction = model(x)
                # save only unpadded part for each utt in batch
                for j in range(len(num_frs)):
                    loglikes = prediction[j, :, :].data.cpu()
                    loglikes_j = loglikes[:num_frs[j], :]
                    if log_prior is not None:
                        loglikes_j = loglikes_j - log_prior

                    llout[utt_ids[j][0]] = loglikes_j

                print("Process batch [{}/{}]".format(batch_idx+1, len(test_dataloader)))
0
def _strip_module_prefix(state_dict):
    """Remove the ``module.`` prefix that DataParallel adds to parameter names.

    Keys without the prefix are left untouched, so checkpoints saved from a
    bare model load as well (blindly dropping the first seven characters
    would corrupt their names).

    Args:
        state_dict: mapping of parameter name to value.

    Returns:
        ``state_dict`` itself when no key carries the prefix, otherwise a new
        ``OrderedDict`` in the same order with the prefix removed.
    """
    from collections import OrderedDict

    prefix = "module."
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items())


def _require_file(path, description):
    """Exit with status 1 and an error on stderr unless ``path`` is a file.

    Args:
        path: the path to check; ``None`` or an empty string counts as missing.
        description: how the file is named in the error message.
    """
    if path and os.path.isfile(path):
        return
    sys.stderr.write('ERROR: The %s %s does not exist!\n' %
                     (description, path))
    # A missing input is a failure: report a non-zero status to the caller.
    sys.exit(1)


def main():
    """Command-line entry point for Horovod sequence training.

    Loads the config and data YAML files, builds the dataset, dataloader,
    model and SGD optimizer, loads the seed model, sets up the lattice decoder
    from ``-den_dir`` and ``-trans_model``, then calls ``run_train_epoch`` once
    per epoch. Rank 0 writes ``model.se.<epoch>.tar`` into ``-exp_dir`` after
    every epoch.

    The process exits with status 1 when the seed model, HCLG.fst, words.txt,
    phones/silence.csl or the transition model is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-config")
    parser.add_argument("-data", help="data yaml file")
    parser.add_argument("-data_path",
                        default='',
                        type=str,
                        help="path of data files")
    parser.add_argument("-seed_model", help="the seed nerual network model")
    parser.add_argument("-exp_dir", help="the directory to save the outputs")
    parser.add_argument("-transform",
                        help="feature transformation matrix or mvn statistics")
    parser.add_argument("-criterion",
                        type=str,
                        choices=["mmi", "mpfe", "smbr"],
                        help="set the sequence training crtierion")
    parser.add_argument(
        "-trans_model",
        help="the HMM transistion model, used for lattice generation")
    parser.add_argument(
        "-prior_path",
        help="the prior for decoder, usually named as final.occs in kaldi setup"
    )
    parser.add_argument(
        "-den_dir",
        help="the decoding graph directory to find HCLG and words.txt files")
    parser.add_argument("-lr", type=float, help="set the learning rate")
    parser.add_argument("-ce_ratio",
                        default=0.1,
                        type=float,
                        help="the ratio for ce regularization")
    parser.add_argument("-momentum",
                        default=0,
                        type=float,
                        help="set the momentum")
    parser.add_argument("-batch_size",
                        default=32,
                        type=int,
                        help="Override the batch size in the config")
    parser.add_argument("-data_loader_threads",
                        default=0,
                        type=int,
                        help="number of workers for data loading")
    parser.add_argument("-max_grad_norm",
                        default=5,
                        type=float,
                        help="max_grad_norm for gradient clipping")
    parser.add_argument("-sweep_size",
                        default=100,
                        type=float,
                        help="process n hours of data per sweep (default:100)")
    parser.add_argument("-num_epochs",
                        default=1,
                        type=int,
                        help="number of training epochs (default:1)")
    parser.add_argument('-print_freq',
                        default=10,
                        type=int,
                        metavar='N',
                        help='print frequency (default: 10)')
    parser.add_argument('-save_freq',
                        default=1000,
                        type=int,
                        metavar='N',
                        help='save model frequency (default: 1000)')

    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)

    config['data_path'] = args.data_path

    config["sweep_size"] = args.sweep_size

    print("pytorch version:{}".format(th.__version__))

    with open(args.data) as f:
        data = yaml.safe_load(f)
    config["source_paths"] = list(data['clean_source'].values())

    print("Experiment starts with config {}".format(
        json.dumps(config, sort_keys=True, indent=4)))

    # Initialize Horovod
    hvd.init()

    th.cuda.set_device(hvd.local_rank())

    print("Run experiments with world size {}".format(hvd.size()))

    dataset = SpeechDataset(config)
    if args.transform is not None and os.path.isfile(args.transform):
        # pickle can execute arbitrary code while loading: only pass trusted
        # transform files.
        with open(args.transform, 'rb') as f:
            dataset.transform = pickle.load(f)

    train_dataloader = SeqDataloader(dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.data_loader_threads,
                                     distributed=True,
                                     test_only=False)

    print("Data loader set up successfully!")
    print("Number of minibatches: {}".format(len(train_dataloader)))

    # exist_ok avoids a crash when several ranks create the directory at once.
    os.makedirs(args.exp_dir, exist_ok=True)

    # create model
    model_config = config["model_config"]
    lstm = LSTMStack(model_config["feat_dim"], model_config["hidden_size"],
                     model_config["num_layers"], model_config["dropout"], True)
    model = NnetAM(lstm, model_config["hidden_size"] * 2,
                   model_config["label_size"])

    model.cuda()

    # setup the optimizer
    optimizer = th.optim.SGD(model.parameters(),
                             lr=args.lr,
                             momentum=args.momentum)

    # Broadcast parameters and optimizer state from rank 0 to all other
    # processes.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # Add Horovod Distributed Optimizer
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())

    # The seed model is mandatory for sequence training.
    _require_file(args.seed_model, 'model file')
    checkpoint = th.load(args.seed_model)
    model.load_state_dict(_strip_module_prefix(checkpoint['model']))
    print("=> loaded checkpoint '{}' ".format(args.seed_model))

    HCLG = args.den_dir + "/HCLG.fst"
    words_txt = args.den_dir + "/words.txt"
    silence_phones = args.den_dir + "/phones/silence.csl"

    _require_file(HCLG, 'HCLG file')
    _require_file(words_txt, 'words.txt file')
    _require_file(silence_phones, 'silence phone file')

    # The first line of silence.csl is a colon-separated list of integer ids.
    with open(silence_phones) as f:
        silence_ids = [int(i) for i in f.readline().strip().split(':')]

    _require_file(args.trans_model, 'trans_model')
    trans_model = kaldi_hmm.TransitionModel()
    with kaldi_util.io.xopen(args.trans_model) as ki:
        trans_model.read(ki.stream(), ki.binary)

    # now we can setup the decoder
    decoder_opts = LatticeFasterDecoderOptions()
    decoder_opts.beam = config["decoder_config"]["beam"]
    decoder_opts.lattice_beam = config["decoder_config"]["lattice_beam"]
    decoder_opts.max_active = config["decoder_config"]["max_active"]
    acoustic_scale = config["decoder_config"]["acoustic_scale"]
    # To produce raw state-level lattice instead of compact lattice
    decoder_opts.determinize_lattice = False
    asr_decoder = MappedLatticeFasterRecognizer.from_files(
        args.trans_model,
        HCLG,
        words_txt,
        acoustic_scale=acoustic_scale,
        decoder_opts=decoder_opts)

    prior = kaldi_util.io.read_matrix(args.prior_path).numpy()
    # Normalise the first row to a distribution, then take the log.
    log_prior = th.tensor(np.log(prior[0] / np.sum(prior[0])), dtype=th.float)

    model.train()
    for epoch in range(args.num_epochs):

        run_train_epoch(model, optimizer, log_prior.cuda(), train_dataloader,
                        epoch, asr_decoder, trans_model, silence_ids, args)

        # save model (rank 0 only)
        if hvd.rank() == 0:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
            }
            output_file = args.exp_dir + '/model.se.' + str(epoch) + '.tar'
            th.save(checkpoint, output_file)
Beispiel #4
0
def _strip_module_prefix(state_dict):
    """Remove the ``module.`` prefix that DataParallel adds to parameter names.

    Each key is checked on its own, so an empty dict or a dict with only some
    prefixed keys is handled (the previous code looked at the last key only).

    Args:
        state_dict: mapping of parameter name to value.

    Returns:
        ``state_dict`` itself when no key carries the prefix, otherwise a new
        ``OrderedDict`` in the same order with the prefix removed.
    """
    from collections import OrderedDict

    prefix = "module."
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items())


def main():
    """Command-line entry point for Horovod chain-model training.

    Loads the config and data YAML files, builds the dataset, dataloader,
    LSTM acoustic model and Adam optimizer, optionally loads a seed model,
    reads the chain transition model, tree and denominator FST from
    ``-chain_dir``, sets up the aligner from ``-ali_dir`` / ``-lang_dir``, then
    calls ``run_train_epoch`` once per epoch. Rank 0 writes
    ``chain.model.<epoch>.tar`` into ``-exp_dir`` after every epoch.

    The process exits with status 1 when ``<chain_dir>/0.trans_mdl`` is
    missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-config")
    parser.add_argument("-data", help="data yaml file")
    parser.add_argument("-dataPath",
                        default='',
                        type=str,
                        help="path of data files")
    parser.add_argument("-seed_model",
                        default='',
                        help="the seed nerual network model")
    parser.add_argument("-exp_dir", help="the directory to save the outputs")
    parser.add_argument("-transform",
                        help="feature transformation matrix or mvn statistics")
    parser.add_argument(
        "-ali_dir",
        help="the directory to load trans_model and tree used for alignments")
    parser.add_argument("-lang_dir",
                        help="the lexicon directory to load L.fst")
    parser.add_argument(
        "-chain_dir",
        help=
        "the directory to load trans_model, tree and den.fst for chain model")
    parser.add_argument("-lr", type=float, help="set the base learning rate")
    parser.add_argument(
        "-warmup_steps",
        default=4000,
        type=int,
        help="the number of warmup steps to adjust the learning rate")
    parser.add_argument("-xent_regularize",
                        default=0,
                        type=float,
                        help="cross-entropy regularization weight")
    parser.add_argument("-momentum",
                        default=0,
                        type=float,
                        help="set the momentum")
    parser.add_argument("-weight_decay",
                        default=1e-4,
                        type=float,
                        help="set the L2 regularization weight")
    parser.add_argument("-batch_size",
                        default=32,
                        type=int,
                        help="Override the batch size in the config")
    parser.add_argument("-data_loader_threads",
                        default=0,
                        type=int,
                        help="number of workers for data loading")
    parser.add_argument("-max_grad_norm",
                        default=5,
                        type=float,
                        help="max_grad_norm for gradient clipping")
    parser.add_argument("-sweep_size",
                        default=100,
                        type=float,
                        help="process n hours of data per sweep (default:100)")
    parser.add_argument("-num_epochs",
                        default=1,
                        type=int,
                        help="number of training epochs (default:1)")
    parser.add_argument(
        "-anneal_lr_epoch",
        default=2,
        type=int,
        help="start to anneal the learning rate from this epoch")
    parser.add_argument("-anneal_lr_ratio",
                        default=0.5,
                        type=float,
                        help="the ratio to anneal the learning rate ratio")
    parser.add_argument('-print_freq',
                        default=10,
                        type=int,
                        metavar='N',
                        help='print frequency (default: 10)')
    parser.add_argument('-save_freq',
                        default=1000,
                        type=int,
                        metavar='N',
                        help='save model frequency (default: 1000)')

    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)

    config["sweep_size"] = args.sweep_size

    print("pytorch version:{}".format(th.__version__))

    with open(args.data) as f:
        data = yaml.safe_load(f)
    config["source_paths"] = list(data['clean_source'].values())
    if 'dir_noise' in data:
        config["dir_noise_paths"] = list(data['dir_noise'].values())
    if 'rir' in data:
        config["rir_paths"] = list(data['rir'].values())
    config['data_path'] = args.dataPath

    print("Experiment starts with config {}".format(
        json.dumps(config, sort_keys=True, indent=4)))

    # Initialize Horovod
    hvd.init()

    th.cuda.set_device(hvd.local_rank())

    print("Run experiments with world size {}".format(hvd.size()))

    dataset = SpeechDataset(config)
    if args.transform is not None and os.path.isfile(args.transform):
        # pickle can execute arbitrary code while loading: only pass trusted
        # transform files.
        with open(args.transform, 'rb') as f:
            dataset.transform = pickle.load(f)

    train_dataloader = SeqDataloader(dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.data_loader_threads,
                                     distributed=True,
                                     test_only=False)

    print("Data loader set up successfully!")
    print("Number of minibatches: {}".format(len(train_dataloader)))

    # exist_ok avoids a crash when several ranks create the directory at once.
    os.makedirs(args.exp_dir, exist_ok=True)

    # create model
    model_config = config["model_config"]
    model = lstm.LSTMAM(model_config["feat_dim"], model_config["label_size"],
                        model_config["hidden_size"],
                        model_config["num_layers"], model_config["dropout"],
                        True)

    model.cuda()

    # setup the optimizer
    optimizer = th.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)

    # Broadcast parameters and optimizer state from rank 0 to all other
    # processes.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # Add Horovod Distributed Optimizer
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())

    # The seed model is optional: without one, training starts from the
    # freshly initialised weights.
    if os.path.isfile(args.seed_model):
        checkpoint = th.load(args.seed_model)
        model.load_state_dict(_strip_module_prefix(checkpoint['model']))
        print("=> loaded checkpoint '{}' ".format(args.seed_model))
    elif args.seed_model:
        # A path was given but does not exist: keep going as before, but say
        # so instead of ignoring it silently.
        sys.stderr.write(
            'WARNING: The seed model %s does not exist, '
            'training starts from scratch!\n' % (args.seed_model))

    ali_model = args.ali_dir + "/final.mdl"
    ali_tree = args.ali_dir + "/tree"
    L_fst = args.lang_dir + "/L.fst"
    disambig = args.lang_dir + "/phones/disambig.int"

    den_fst = kaldi_fst.StdVectorFst.read(args.chain_dir + "/den.fst")
    chain_model_path = args.chain_dir + "/0.trans_mdl"
    chain_tree_path = args.chain_dir + "/tree"

    if not os.path.isfile(chain_model_path):
        sys.stderr.write('ERROR: The trans_model %s does not exist!\n' %
                         (chain_model_path))
        # A missing input is a failure: report a non-zero status.
        sys.exit(1)

    chain_trans_model = kaldi_hmm.TransitionModel()
    with kaldi_util.io.xopen(chain_model_path) as ki:
        chain_trans_model.read(ki.stream(), ki.binary)

    chain_tree = kaldi_tree.ContextDependency()
    with kaldi_util.io.xopen(chain_tree_path) as ki:
        chain_tree.read(ki.stream(), ki.binary)

    # chain supervision options
    supervision_opts = kaldi_chain.SupervisionOptions()
    supervision_opts.convert_to_pdfs = True
    supervision_opts.frame_subsampling_factor = 3
    supervision_opts.left_tolerance = 5
    supervision_opts.right_tolerance = 5

    # chain training options
    chain_opts = kaldi_chain.ChainTrainingOptions()
    chain_opts.leaky_hmm_coefficient = 1e-4
    chain_opts.xent_regularize = args.xent_regularize

    # setup the aligner
    aligner = kaldi_align.MappedAligner.from_files(ali_model,
                                                   ali_tree,
                                                   L_fst,
                                                   None,
                                                   disambig,
                                                   None,
                                                   beam=10,
                                                   transition_scale=1.0,
                                                   self_loop_scale=0.1,
                                                   acoustic_scale=0.1)
    den_graph = kaldi_chain.DenominatorGraph(den_fst,
                                             model_config["label_size"])

    model.train()
    for epoch in range(args.num_epochs):

        # anneal learning rate
        if epoch > args.anneal_lr_epoch:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= args.anneal_lr_ratio

        run_train_epoch(model, optimizer, train_dataloader, epoch,
                        chain_trans_model, chain_tree, supervision_opts,
                        aligner, den_graph, chain_opts, args)

        # save model (rank 0 only)
        if hvd.rank() == 0:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
            }
            output_file = args.exp_dir + '/chain.model.' + str(epoch) + '.tar'
            th.save(checkpoint, output_file)