Example #1
0
def _load_pytorch_transformer_model(device, dynamic_axes=False, legacy_api=False):
    """Load the external PyTorch transformer sample modules and build the model.

    Imports ``pt_model.py``, ``ort_utils.py`` and ``utils.py`` from the
    ``samples/python/pytorch_transformer`` directory, instantiates the
    TransformerModel on *device*, picks the matching model description
    (legacy vs. new API, static vs. dynamic axes) and prepares the data.

    Returns a tuple of
    (model, model_desc, my_loss, get_batch, train_data, val_data, test_data).
    """
    sample_dir = os.path.join('..', '..', '..', 'samples', 'python', 'pytorch_transformer')
    pt_model = _utils.import_module_from_file(os.path.join(sample_dir, 'pt_model.py'))
    ort_utils = _utils.import_module_from_file(os.path.join(sample_dir, 'ort_utils.py'))
    utils = _utils.import_module_from_file(os.path.join(sample_dir, 'utils.py'))

    # Modeling
    model = pt_model.TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device)
    my_loss = ort_utils.my_loss

    # Select the description factory via a dispatch table keyed on
    # (legacy_api, dynamic_axes); bool() normalizes truthy arguments.
    description_factories = {
        (True, True): ort_utils.legacy_transformer_model_description_dynamic_axes,
        (True, False): ort_utils.legacy_transformer_model_description,
        (False, True): ort_utils.transformer_model_description_dynamic_axes,
        (False, False): ort_utils.transformer_model_description,
    }
    model_desc = description_factories[(bool(legacy_api), bool(dynamic_axes))]()

    # Preparing data
    train_data, val_data, test_data = utils.prepare_data(device, 20, 20)
    return model, model_desc, my_loss, utils.get_batch, train_data, val_data, test_data
Example #2
0
def train_ort_model(epoch=1):
    """Train the sample TransformerModel for one epoch with plain PyTorch.

    Loads ``pt_model.py`` from the current directory, trains on CUDA with
    SGD + CrossEntropyLoss over ``bptt``-sized slices of the training data,
    and logs loss/perplexity every ``log_interval`` batches.

    Args:
        epoch: Epoch number used only in the progress log line.
    """
    device = "cuda"
    ntokens = 28785
    bptt = 35          # sequence length per training slice
    initial_lr = 0.001

    train_data, val_data, test_data = prepare_data(device, 20, 20)
    pt_model_path = os.path.join('pt_model.py')
    pt_model = _utils.import_module_from_file(pt_model_path)
    # Use ntokens instead of repeating the magic constant 28785.
    model = pt_model.TransformerModel(ntokens, 200, 2, 200, 2, 0.2).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=initial_lr)

    model.train()
    total_loss = 0.
    start_time = time.time()
    log_interval = 200  # hoisted: loop-invariant
    # Range bound uses bptt rather than a duplicated literal 35.
    for batch, i in enumerate(range(0, train_data.size(0) - bptt, bptt)):
        data, targets = get_batch(train_data, i)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
        # Clip gradients to stabilize training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()

        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| {} | epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.3f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    device, epoch, batch, len(train_data) // bptt, initial_lr,
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
Example #3
0
def train_ort_model(epoch=1):
    """Train the sample TransformerModel for one epoch with ORTTrainer.

    Loads ``pt_model.py`` from the current directory, wraps the model in an
    ONNX Runtime ``ORTTrainer`` (SGD optimizer, external ``my_loss``), trains
    over ``bptt``-sized slices of the training data, and logs loss/perplexity
    every ``log_interval`` batches.

    Args:
        epoch: Epoch number used only in the progress log line.
    """
    device = "cuda"
    ntokens = 28785
    bptt = 35          # sequence length per training slice
    batch_size = 20
    initial_lr = 0.001

    train_data, val_data, test_data = prepare_data(device, 20, 20)
    pt_model_path = os.path.join('pt_model.py')
    pt_model = _utils.import_module_from_file(pt_model_path)
    # Use ntokens instead of repeating the magic constant 28785.
    model = pt_model.TransformerModel(ntokens, 200, 2, 200, 2, 0.2).to(device)

    # Static input/output description for the ORT training session.
    model_desc = {'inputs':  [('input1', [bptt, batch_size]),
                              ('label', [bptt * batch_size])],
                  'outputs': [('loss', [], True),
                              ('predictions', [bptt, batch_size, ntokens])]}

    opts = orttrainer.ORTTrainerOptions({'device' : {'id' : device}})
    optim_config = optim.SGDConfig(lr=initial_lr)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)

    total_loss = 0.
    start_time = time.time()
    log_interval = 200  # hoisted: loop-invariant
    # Range bound uses bptt rather than a duplicated literal 35.
    for batch, i in enumerate(range(0, train_data.size(0) - bptt, bptt)):
        data, targets = get_batch(train_data, i)
        output = trainer.train_step(data, targets)
        total_loss += output[0].item()  # first output is the scalar loss

        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| {} | epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.3f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    device, epoch, batch, len(train_data) // bptt, initial_lr,
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()