def load(args):
    """Build the fused M2M-100 model for inference and load its fusion weights.

    Args:
        args: Namespace-like object. This function reads ``args.no_cuda``,
            ``args.checkpoint`` (path passed to ``torch.load``) and
            ``args.num_gpus``.

    Returns:
        A ``(model, tokenizer, device)`` tuple: the model in eval mode
        (wrapped in ``torch.nn.DataParallel`` when ``args.num_gpus > 1``),
        the M2M-100 tokenizer, and the selected ``torch.device``.
    """
    print('loading model')
    use_cuda = torch.cuda.is_available() and not args.no_cuda
    device = torch.device("cuda" if use_cuda else "cpu")

    # Load M2M-100 model
    config = M2M100Config.from_pretrained("facebook/m2m100_418M")
    # NOTE(review): the meaning of method 1 is defined by FusedM2M, which is
    # not visible here -- confirm against its implementation.
    config.method = 1
    m2m = M2M100ForConditionalGeneration.from_pretrained(
        "facebook/m2m100_418M", config=config)
    tokenizer = M2M100Tokenizer.from_pretrained('facebook/m2m100_418M')

    # Build Fused Model and load parameters from local checkpoint
    model = FusedM2M(config, None, m2m)
    # Deserialize onto the CPU so a checkpoint saved from a GPU run can still
    # be loaded when CUDA is unavailable or disabled via ``args.no_cuda``;
    # the weights reach the target device through ``model.to(device)`` below.
    state_dict = torch.load(args.checkpoint, map_location="cpu")
    # Keep only the entries whose key contains 'fuse' (the fusion layers);
    # every other parameter keeps its pretrained value, hence strict=False.
    fuse_state = {k: v for k, v in state_dict.items() if 'fuse' in k}
    model.load_state_dict(fuse_state, strict=False)
    # Per the original author's note, this takes the inner M2M100Model out of
    # the M2M100ForConditionalGeneration wrapper.
    model = model.model

    model.to(device)
    if args.num_gpus > 1:
        model = torch.nn.DataParallel(model)
    model.eval()
    return model, tokenizer, device
Ejemplo n.º 2
0
def filter_none(example):
    """Return True only if both sides of the translation pair are not None.

    Reads the module-level ``source_lang`` and ``target_lang`` names to pick
    the two entries out of ``example["translation"]``.
    """
    pair = example["translation"]
    if pair[source_lang] is None:
        return False
    return pair[target_lang] is not None


# Build the tokenized dataset, or reuse the copy previously saved to "data".
if not args.load_local_dataset:
    # Drop examples rejected by filter_none, then run preprocess in batches.
    tokenized_datasets = raw_datasets.filter(filter_none)
    tokenized_datasets = tokenized_datasets.map(preprocess, batched=True)
    # Save the result so later runs can load it from disk instead.
    tokenized_datasets.save_to_disk("data")
else:
    tokenized_datasets = load_from_disk("data")

# Prepare models: load the pretrained M2M-100 config, record the selected
# fusion method on it, then build the M2M-100 model from that config.
config = M2M100Config.from_pretrained("facebook/m2m100_418M")
config.method = fuse_method
m2m = M2M100ForConditionalGeneration.from_pretrained(
    "facebook/m2m100_418M",
    config=config,
)
# Build the fused model from the config, the BERT model and M2M-100.
fused_model = FusedM2M(config, bert, m2m)

# DEBUG: Check the weight of layers
# shared_weight = m2m.model.shared.weight.data.clone().detach()
# layer_1_weight = m2m.model.encoder.layers[0].fc1.weight.data.clone().detach()
# fuse_12_weight = m2m.model.encoder.layers[-1].fuse_layer.weight.data.clone().detach()

# Load state dict from local checkpoint
if checkpoint:
    state_dict = torch.load(f'{checkpoint}/pytorch_model.bin')
    state_dict = {k: v
                  for k, v in state_dict.items()