Code example #1
    # N_EPOCHS = 10
    # CLIP = 1

    best_valid_loss = float('inf')

    for epoch in range(args.epoches):

        start_time = time.time()

        train_loss = train(transformer_model, optimizer, criterion,
                           args.gradient_clip, device)
        valid_loss = evaluate(transformer_model, criterion, device)

        end_time = time.time()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        # checkpoint whenever the validation loss improves
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(
                transformer_model.state_dict(),
                './models/transformer/transformer-model_{}.pt'.format(epoch + 1))

        logger.info('Epoch: {:02} | Time: {}m {}s'.format(
            epoch + 1, epoch_mins, epoch_secs))
        logger.info('\tTrain Loss: {:.3f} | Train PPL: {:7.3f}'.format(
            train_loss, math.exp(train_loss)))
        logger.info('\t Val. Loss: {:.3f} |  Val. PPL: {:7.3f}'.format(
            valid_loss, math.exp(valid_loss)))
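
This excerpt assumes that time, math, and torch are imported earlier in the script and that train, evaluate, and a configured logger are defined elsewhere. It also calls an epoch_time helper that is not shown; a minimal sketch consistent with how it is called, returning whole minutes and leftover seconds:

    def epoch_time(start_time, end_time):
        # split an elapsed wall-clock interval into (minutes, seconds)
        elapsed = end_time - start_time
        elapsed_mins = int(elapsed / 60)
        elapsed_secs = int(elapsed - elapsed_mins * 60)
        return elapsed_mins, elapsed_secs
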
Code example #2
import argparse
import os

import torch
import torch.nn as nn
import yaml
import horovod.torch as hvd
from tensorboardX import SummaryWriter  # torch.utils.tensorboard also provides SummaryWriter

# Project-local helpers assumed importable from this repository:
# AttrDict, AudioDateset, Transformer, init_logger, init_parameters,
# count_parameters, train, save_model.


def main():
    ''' Main function '''
    hvd.init()  # Horovod must be initialized before any other hvd.* call
    parser = argparse.ArgumentParser()
    parser.add_argument('-config', type=str, default='config/rnnt.yaml')
    parser.add_argument('-load_model', type=str, default=None)
    parser.add_argument('-fp16_allreduce',
                        action='store_true',
                        default=False,
                        help='use fp16 compression during allreduce')
    parser.add_argument('-batches_per_allreduce',
                        type=int,
                        default=1,
                        help='number of batches processed locally before '
                        'executing allreduce across workers; it multiplies '
                        'total batch size.')
    parser.add_argument(
        '-num_workers',
        type=int,
        default=0,
        help='how many subprocesses to use for data loading. '
        '0 means that the data will be loaded in the main process')
    parser.add_argument('-log', type=str, default='train.log')
    opt = parser.parse_args()

    configfile = open(opt.config)
    config = AttrDict(yaml.safe_load(configfile))

    global global_step
    global_step = 0

    if hvd.rank() == 0:
        exp_name = config.data.name
        if not os.path.isdir(exp_name):
            os.mkdir(exp_name)
        logger = init_logger(exp_name + '/' + opt.log)
    else:
        logger = None

    if torch.cuda.is_available():
        torch.cuda.set_device(hvd.local_rank())
        torch.cuda.manual_seed(config.training.seed)
        torch.backends.cudnn.deterministic = True
    else:
        raise NotImplementedError

    #========= Build DataLoader =========#
    train_dataset = AudioDateset(config.data, 'train')
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.train.batch_size,
        sampler=train_sampler)

    assert train_dataset.vocab_size == config.model.vocab_size

    #========= Build A Model Or Load Pre-trained Model=========#
    model = Transformer(config.model)

    if hvd.rank() == 0:
        n_params, enc_params, dec_params = count_parameters(model)
        logger.info('# the number of parameters in the whole model: %d' %
                    n_params)
        logger.info('# the number of parameters in encoder: %d' % enc_params)
        logger.info('# the number of parameters in decoder: %d' % dec_params)

    model.cuda()

    # define an optimizer; the learning rate is read from the YAML config
    # (the exact key, optim.lr, is an assumption)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config.optim.lr,
                                 betas=(0.9, 0.98),
                                 eps=1e-9)

    # Horovod: (optional) compression algorithm.
    compression = hvd.Compression.fp16 if opt.fp16_allreduce else hvd.Compression.none

    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        compression=compression)

    # load a pre-trained checkpoint on the root rank; it is broadcast to the
    # other workers below
    if opt.load_model is not None and hvd.rank() == 0:
        checkpoint = torch.load(opt.load_model)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info('Loaded pretrained model and previous optimizer state!')
    elif hvd.rank() == 0:
        init_parameters(model)
        logger.info('Initialized all parameters!')

    # Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # define loss function
    crit = nn.CrossEntropyLoss(ignore_index=0)

    # create a visualizer
    if config.training.visualization and hvd.rank() == 0:
        visualizer = SummaryWriter(exp_name + '/log')
        logger.info('Created a visualizer.')
    else:
        visualizer = None

    for epoch in range(config.training.epoches):
        train(epoch, model, crit, optimizer, train_loader, train_sampler,
              logger, visualizer, config)

        if hvd.rank() == 0:
            save_model(epoch, model, optimizer, config, logger)

    if hvd.rank() == 0:
        logger.info('Training process finished')
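
The parsed YAML is wrapped in an AttrDict so that nested config sections can be read with attribute access (config.data.name, config.training.seed). AttrDict is a project-local helper; a minimal sketch of the recursive attribute-access wrapper it presumably provides:

    class AttrDict(dict):
        # dict subclass whose keys are also readable as attributes, applied
        # recursively so nested YAML sections work (e.g. config.data.name)
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            for key, value in self.items():
                if isinstance(value, dict):
                    self[key] = AttrDict(value)

        def __getattr__(self, item):
            try:
                return self[item]
            except KeyError:
                raise AttributeError(item)

Because every hvd.* call assumes one worker process per GPU, the script is meant to be launched through Horovod's runner, e.g. horovodrun -np 4 python train.py -config config/rnnt.yaml (the file name train.py is an assumption).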