def _setup_device(args):
    """Pick CPU or the first requested GPU; normalizes args.gpu_ids to a list."""
    if args.gpu_ids is None:
        return torch.device("cpu")
    if isinstance(args.gpu_ids, int):
        args.gpu_ids = [args.gpu_ids]
    device = torch.device("cuda:%d" % args.gpu_ids[0])
    torch.cuda.set_device(device)
    return device


def _save_checkpoint(model, gpu_ids, path):
    """Save model weights to `path`, unwrapping nn.DataParallel when present."""
    if gpu_ids and len(gpu_ids) > 1:
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save(state_dict, path, _use_new_zipfile_serialization=False)


def _load_pretrained_model(args, logger, device):
    """Rebuild a BMGFModel from a pretrained directory and load its best epoch.

    Tries the log's best epochs by accf1, then f1, accuracy, loss; prefers the
    test split's checkpoint, then valid, then train. Raises ValueError when no
    best epoch can be recovered from the log.
    """
    config = load_config(
        os.path.join(args.pretrained_model_path, "BMGFModel.config"))
    for by in ["accf1", "f1", "accuracy", "loss"]:
        best_epochs = get_best_epochs(
            os.path.join(args.pretrained_model_path, "BMGFModel.log"), by=by)
        if len(best_epochs) > 0:
            break
    logger.info("retrieve the best epochs for BMGFModel: %s" % (best_epochs))
    if len(best_epochs) == 0:
        raise ValueError("Error: cannot load BMGFModel from %s."
                         % (args.pretrained_model_path))
    model = BMGFModel(**(config._asdict()))
    # prefer the test checkpoint, then valid, then train (KeyError if absent,
    # matching the original unconditional fallback)
    split = "test" if "test" in best_epochs else (
        "valid" if "valid" in best_epochs else "train")
    model.load_state_dict(
        torch.load(os.path.join(args.pretrained_model_path,
                                "epoch%d.pt" % (best_epochs[split])),
                   map_location=device))
    # the requested dropout may differ from the pretrained config's
    if config.dropout != args.dropout:
        change_dropout_rate(model, args.dropout)
    return model


def _load_datasets(args, logger):
    """Load train/valid/test datasets; optionally fold explicit data into train."""
    datasets = OrderedDict({
        "train": Dataset().load_pt(args.train_dataset_path),
        "valid": Dataset().load_pt(args.valid_dataset_path),
        "test": Dataset().load_pt(args.test_dataset_path)
    })
    if args.explicit_dataset_path != "":
        explicit_dataset = Dataset().load_pt(args.explicit_dataset_path)
        datasets["train"].data.extend(explicit_dataset.data)
        del explicit_dataset  # release the merged copy before training
    logger.info("train:valid:test = %d:%d:%d" % (len(
        datasets["train"]), len(datasets["valid"]), len(datasets["test"])))
    return datasets


def _build_data_loaders(args, datasets):
    """Build one DataLoader per split with a relation-filtered collate_fn."""
    rel_map = defaultdict(int)
    for r in args.relations:
        for k in Dataset.rel_map_4.keys():
            if k.startswith(r):
                rel_map[k] = 1
    assert len(rel_map) > 0
    # RoBERTa's vocabulary uses 1 as the padding id; others here use 0
    pad_id = 1 if args.encoder == "roberta" else 0
    batchify = partial(Dataset.batchify,
                       rel_map=rel_map,
                       min_arg=args.min_arg,
                       max_arg=args.max_arg,
                       pad_id=pad_id)
    data_loaders = OrderedDict()
    for data_type in datasets:
        is_train = data_type == "train"
        sampler = Sampler(datasets[data_type],
                          group_by=["arg1", "arg2"],
                          batch_size=args.batch_size,
                          shuffle=is_train,
                          drop_last=False)
        data_loaders[data_type] = data.DataLoader(datasets[data_type],
                                                  batch_sampler=sampler,
                                                  collate_fn=batchify,
                                                  pin_memory=is_train)
    return data_loaders


def _update_best(args, logger, model, data_type, epoch, value,
                 best_values, best_epochs, metric_key, log_label,
                 minimize=False):
    """Record `value` if it ties or beats the best so far for this split.

    Logs the new best and, when args.save_best selects `metric_key`,
    checkpoints the model to <save_model_dir>/<data_type>_best.pt.
    """
    if minimize:
        improved = value <= best_values[data_type]
    else:
        improved = value >= best_values[data_type]
    if not improved:
        return
    best_values[data_type] = value
    best_epochs[data_type] = epoch
    logger.info(
        "data_type: {:<5s}\tbest pdtb-{}: {:.4f} (epoch: {:0>3d})".format(
            data_type, log_label, value, epoch))
    if args.save_best == metric_key:
        _save_checkpoint(
            model, args.gpu_ids,
            os.path.join(args.save_model_dir, "%s_best.pt" % (data_type)))


def train(args, logger, writer):
    """Train a BMGFModel: build or load the model, load the datasets, run
    the epoch loop over train/valid/test, and checkpoint/log the best model.

    Args:
        args: hyper-parameter/path namespace. Mutated: ``num_rels`` is forced
            to 2 (binary classification) and ``gpu_ids`` is normalized to a
            list.
        logger: logging.Logger for progress reports.
        writer: summary writer forwarded to train_epoch/eval_epoch.
    """
    device = _setup_device(args)
    args.num_rels = 2  # for binary classification

    if args.pretrained_model_path:
        model = _load_pretrained_model(args, logger, device)
    else:
        model = BMGFModel(**vars(args))
    model.set_finetune(args.finetune)
    if args.gpu_ids and len(args.gpu_ids) > 1:
        model = nn.DataParallel(model, device_ids=args.gpu_ids)
    model = model.to(device)
    logger.info(model)
    logger.info("num of trainable parameters: %d" % (sum(
        p.numel() for p in model.parameters() if p.requires_grad)))

    datasets = _load_datasets(args, logger)
    data_loaders = _build_data_loaders(args, datasets)

    # optimizer and losses
    optimizer = Adam(model.parameters(),
                     lr=args.lr,
                     weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # best-so-far trackers per split (loss is minimized, the rest maximized)
    best_losses = {dataset: INF for dataset in datasets}
    best_loss_epochs = {dataset: -1 for dataset in datasets}
    best_accs = {dataset: _INF for dataset in datasets}
    best_acc_epochs = {dataset: -1 for dataset in datasets}
    best_f1s = {dataset: _INF for dataset in datasets}
    best_f1_epochs = {dataset: -1 for dataset in datasets}
    best_accf1s = {dataset: _INF for dataset in datasets}
    best_accf1_epochs = {dataset: -1 for dataset in datasets}

    for epoch in range(args.epochs):
        for data_type, data_loader in data_loaders.items():
            if data_type == "train":
                mean_loss, results = train_epoch(args, logger, writer, model,
                                                 optimizer, data_type,
                                                 data_loader, device, epoch)
            else:
                mean_loss, results = eval_epoch(args, logger, writer, model,
                                                data_type, data_loader,
                                                device, epoch)
            save_results(
                results,
                os.path.join(args.save_model_dir,
                             "%s_results%d.json" % (data_type, epoch)))

            acc = results["evaluation"]["accuracy"]["overall"]
            f1 = results["evaluation"]["precision_recall_f1"]["overall"][-1]
            _update_best(args, logger, model, data_type, epoch, mean_loss,
                         best_losses, best_loss_epochs, "loss", "loss",
                         minimize=True)
            _update_best(args, logger, model, data_type, epoch, acc,
                         best_accs, best_acc_epochs, "acc", "accuracy")
            _update_best(args, logger, model, data_type, epoch, f1,
                         best_f1s, best_f1_epochs, "f1", "f1")
            _update_best(args, logger, model, data_type, epoch, acc + f1,
                         best_accf1s, best_accf1_epochs, "accf1", "accf1")

    # final summary of the best metrics per split
    for data_type in data_loaders:
        for label, values, epochs in [
                ("loss", best_losses, best_loss_epochs),
                ("accuracy", best_accs, best_acc_epochs),
                ("f1", best_f1s, best_f1_epochs),
                ("accf1", best_accf1s, best_accf1_epochs)]:
            logger.info(
                "data_type: {:<5s}\tbest pdtb-{}: {:.4f} (epoch: {:0>3d})".
                format(data_type, label, values[data_type],
                       epochs[data_type]))
]): logger.info("loading data from pt...") for data_type in data_loaders: if finetune_config["model"] in ["RGCN", "RGIN", "RGIN"]: dataset = GraphAdjDataset(list()) dataset.load( os.path.join(finetune_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))) if data_type == "train": np.random.shuffle(dataset.data) dataset.data = dataset.data[:math.ceil( len(dataset.data) * finetune_config["train_ratio"])] sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=finetune_config["batch_size"], shuffle=data_type == "train", drop_last=False) data_loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=GraphAdjDataset.batchify, pin_memory=data_type == "train") else: dataset = EdgeSeqDataset(list()) dataset.load( os.path.join(finetune_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))) if data_type == "train": np.random.shuffle(dataset.data) dataset.data = dataset.data[:math.ceil( len(dataset.data) * finetune_config["train_ratio"])]