Code Example #1
File: run_glue.py   Project: mlpen/Nystromformer
import torch
import torch.nn as nn

########################### Load Model ###########################

data = DatasetProcessor(dataset_root_folder, dataset_config)
model = ModelWrapper(model_config)

checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
print("Model restored", checkpoint_path)

model.init_sen_class()  # initialize the sentence-classification head for the downstream task

# Move the model to GPU and replicate it across all visible devices
model = model.cuda()
model = nn.DataParallel(model, device_ids=device_ids)

# BERT-style AdamW hyper-parameters; the learning rate comes from the task config
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=downsteam_task_config["learning_rate"],
                              betas=(0.9, 0.999),
                              eps=1e-6,
                              weight_decay=0.01)

########################### Train Model ###########################


def train():
    train = downsteam_task_config["train"]
    train_downsteam_data = BertDownsteamDatasetWrapper(
        data, downsteam_task_config["file_path"],
        downsteam_task_config["task"], train)
    train_downsteam_dataloader = torch.utils.data.DataLoader(
        train_downsteam_data,
        batch_size=downsteam_task_config["batch_size"],  # remaining arguments assumed; the source snippet is truncated here
        shuffle=True)
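The restore logic above assumes the checkpoint file is a plain dictionary whose 'model_state_dict' entry holds the model weights. A minimal sketch of the matching save side (save_checkpoint and its signature are illustrative, not taken from the project):

import torch

def save_checkpoint(model, checkpoint_path):
    # model.module.state_dict() keeps the keys free of the 'module.' prefix
    # that nn.DataParallel adds, so the checkpoint can be restored into the
    # unwrapped model, exactly as the loading code above does.
    torch.save({"model_state_dict": model.module.state_dict()}, checkpoint_path)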
Code Example #2
import os
import json
import torch
import torch.nn as nn

checkpoint_path = utils.get_last_checkpoint(checkpoint_dir)  # None when no checkpoint has been written yet

device_ids = list(range(torch.cuda.device_count()))
print(f"GPU list: {device_ids}")

print(json.dumps([model_config, pretraining_task_config, dataset_config], indent=4))

########################### Loading Model ###########################

data = DatasetProcessor(dataset_root_folder, dataset_config)
model = ModelWrapper(model_config)
print(model)

# Count model parameters by multiplying out each weight tensor's shape
num_parameter = 0
for weight in model.parameters():
    print(weight.size())
    size = 1
    for d in weight.size():
        size *= d
    num_parameter += size
print(f"num_parameter: {num_parameter}")

model = model.cuda()
model = nn.DataParallel(model, device_ids=device_ids)

if "from_cp" in config and checkpoint_path is None:
    checkpoint = torch.load(os.path.join(curr_path, 'models', args.model, config["from_cp"]), map_location='cpu')

    cp_pos_encoding = checkpoint['model_state_dict']['model.embeddings.position_embeddings.weight']
    cp_max_seq_len = cp_pos_encoding.size(0)
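The snippet is truncated here. Since it reads out cp_max_seq_len, the warm-start logic presumably adapts the pretrained position embeddings when the checkpoint was trained with a different maximum sequence length. A hedged sketch of one common pattern (the tiling step and the max_seq_len config key are assumptions, not confirmed by the source):

    # Assumed continuation: tile the shorter pretrained position embeddings
    # to cover the current model's sequence range, then load the weights.
    model_max_seq_len = model_config["max_seq_len"]  # assumed config key
    if cp_max_seq_len != model_max_seq_len:
        assert model_max_seq_len % cp_max_seq_len == 0
        checkpoint['model_state_dict']['model.embeddings.position_embeddings.weight'] = \
            cp_pos_encoding.repeat(model_max_seq_len // cp_max_seq_len, 1)

    # The model is already wrapped in nn.DataParallel at this point, so the
    # unprefixed checkpoint keys must be loaded through .module.
    model.module.load_state_dict(checkpoint['model_state_dict'])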