Example #1
# Load data
# adj: adjacency matrices, N*n*n
# feature: vertex features, N*n*f
# labels: vertex labels, N*n*c
# vertex: vertex ids in the global network, N*n

if args.model == "pscn":
    influence_dataset = PatchySanDataSet(args.file_dir,
                                         args.dim,
                                         args.seed,
                                         args.shuffle,
                                         args.model,
                                         sequence_size=args.sequence_size,
                                         stride=1,
                                         neighbor_size=args.neighbor_size)
else:
    influence_dataset = InfluenceDataSet(args.file_dir, args.dim, args.seed,
                                         args.shuffle, args.model)

N = len(influence_dataset)
n_classes = 2
class_weight = influence_dataset.get_class_weight() \
        if args.class_weight_balanced else torch.ones(n_classes)
logger.info("class_weight=%.2f:%.2f", class_weight[0], class_weight[1])

feature_dim = influence_dataset.get_feature_dimension()
n_units = [feature_dim] + \
    [int(x) for x in args.hidden_units.strip().split(",")] + [n_classes]
logger.info("feature dimension=%d", feature_dim)
logger.info("number of classes=%d", n_classes)

train_start, valid_start, test_start = \
    0, int(N * args.train_ratio / 100), int(N * (args.train_ratio + args.valid_ratio) / 100)
Example #2
File: train.py Project: MH-0/RPGAE
def train_evaluate_model(data_path, model_name, model_type, hidden_units,
                         instance_normalization, class_weight_balanced,
                         use_vertex_feature):
    print("EXPERIMENT START------", model_name, model_type, hidden_units)

    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(message)s')  # include timestamp

    # Training settings
    parser = argparse.ArgumentParser()
    parser.add_argument('--tensorboard-log',
                        type=str,
                        default='',
                        help="name of this run")
    parser.add_argument('--model',
                        type=str,
                        default=model_name,
                        help="models used")
    parser.add_argument('--model_type',
                        type=str,
                        default=model_type,
                        help="model type used")
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='Disables CUDA training.')
    parser.add_argument('--seed', type=int, default=42, help='Random seed.')
    parser.add_argument('--epochs',
                        type=int,
                        default=500,
                        help='Number of epochs to train.')
    parser.add_argument('--lr',
                        type=float,
                        default=0.1,
                        help='Initial learning rate.')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=5e-4,
                        help='Weight decay (L2 loss on parameters).')
    parser.add_argument('--dropout',
                        type=float,
                        default=0.2,
                        help='Dropout rate (1 - keep probability).')
    parser.add_argument(
        '--hidden-units',
        type=str,
        default=hidden_units,
        help="Hidden units in each hidden layer, splitted with comma")
    parser.add_argument('--heads',
                        type=str,
                        default="1,1,1",
                        help="Heads in each layer, splitted with comma")
    parser.add_argument('--batch', type=int, default=1024, help="Batch size")
    parser.add_argument('--dim',
                        type=int,
                        default=64,
                        help="Embedding dimension")
    parser.add_argument('--check-point',
                        type=int,
                        default=10,
                        help="Eheck point")
    parser.add_argument('--instance-normalization',
                        action='store_true',
                        default=instance_normalization,
                        help="Enable instance normalization")
    parser.add_argument('--shuffle',
                        action='store_true',
                        default=True,  # with store_true this default makes the flag a no-op
                        help="Shuffle dataset")
    parser.add_argument('--file-dir',
                        type=str,
                        required=False,
                        default=data_path,
                        help="Input file directory")
    parser.add_argument('--train-ratio',
                        type=float,
                        default=75,
                        help="Training ratio (0, 100)")
    parser.add_argument('--valid-ratio',
                        type=float,
                        default=12.5,
                        help="Validation ratio (0, 100)")
    parser.add_argument('--class-weight-balanced',
                        action='store_true',
                        default=class_weight_balanced,
                        help="Adjust weights inversely proportional"
                        " to class frequencies in the input data")
    parser.add_argument('--use-vertex-feature',
                        action='store_true',
                        default=use_vertex_feature,
                        help="Whether to use vertices' structural features")
    parser.add_argument('--sequence-size',
                        type=int,
                        default=16,
                        help="Sequence size (only useful for pscn)")
    parser.add_argument('--neighbor-size',
                        type=int,
                        default=5,
                        help="Neighborhood size (only useful for pscn)")

    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    tensorboard_log_dir = 'tensorboard/%s_%s' % (args.model,
                                                 args.tensorboard_log)
    os.makedirs(tensorboard_log_dir, exist_ok=True)  # make sure the directory exists,
    shutil.rmtree(tensorboard_log_dir)  # then clear logs from any previous run
    tensorboard_logger.configure(tensorboard_log_dir)
    logger.info('tensorboard logging to %s', tensorboard_log_dir)

    # Load data
    # adj: adjacency matrices, N*n*n
    # feature: vertex features, N*n*f
    # labels: vertex labels, N*n*c
    # vertex: vertex ids in the global network, N*n

    if args.model == "pscn":
        influence_dataset = PatchySanDataSet(args.file_dir,
                                             args.dim,
                                             args.seed,
                                             args.shuffle,
                                             args.model,
                                             sequence_size=args.sequence_size,
                                             stride=1,
                                             neighbor_size=args.neighbor_size)
    else:
        influence_dataset = InfluenceDataSet(args.file_dir, args.dim,
                                             args.seed, args.shuffle,
                                             args.model)

    N = len(influence_dataset)
    n_classes = 2
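    # Per-class loss weights: inverse class frequencies when
    # --class-weight-balanced is set, uniform weights otherwise.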
    class_weight = influence_dataset.get_class_weight() \
        if args.class_weight_balanced else torch.ones(n_classes)
    logger.info("class_weight=%.2f:%.2f", class_weight[0], class_weight[1])

    feature_dim = influence_dataset.get_feature_dimension()
    n_units = [feature_dim] + [
        int(x) for x in args.hidden_units.strip().split(",")
    ] + [n_classes]
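    # n_units above lists the layer widths end to end: the input feature
    # dimension, the hidden sizes from --hidden-units, and the class count.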
    logger.info("feature dimension=%d", feature_dim)
    logger.info("number of classes=%d", n_classes)

    train_start, valid_start, test_start = \
        0, int(N * args.train_ratio / 100), int(N * (args.train_ratio + args.valid_ratio) / 100)
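    # Sequential three-way split of the (already shuffled) dataset:
    # [0, valid_start) -> train, [valid_start, test_start) -> valid,
    # [test_start, N) -> test; ChunkSampler draws each range in order.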
    train_loader = DataLoader(influence_dataset,
                              batch_size=args.batch,
                              sampler=ChunkSampler(valid_start - train_start,
                                                   0))
    valid_loader = DataLoader(influence_dataset,
                              batch_size=args.batch,
                              sampler=ChunkSampler(test_start - valid_start,
                                                   valid_start))
    test_loader = DataLoader(influence_dataset,
                             batch_size=args.batch,
                             sampler=ChunkSampler(N - test_start, test_start))

    # Model and optimizer
    if args.model == "gcn":
        model = BatchGCN(
            pretrained_emb=influence_dataset.get_embedding(),
            vertex_feature=influence_dataset.get_vertex_features(),
            use_vertex_feature=args.use_vertex_feature,
            n_units=n_units,
            dropout=args.dropout,
            instance_normalization=args.instance_normalization)
    elif args.model == "gnn_sum":
        model = BatchGNNSUM(
            pretrained_emb=influence_dataset.get_embedding(),
            vertex_feature=influence_dataset.get_vertex_features(),
            use_vertex_feature=args.use_vertex_feature,
            n_units=n_units,
            dropout=args.dropout,
            instance_normalization=args.instance_normalization,
            model_type=args.model_type)
    elif args.model == "gnn_mean":
        model = BatchGNNMEAN(
            pretrained_emb=influence_dataset.get_embedding(),
            vertex_feature=influence_dataset.get_vertex_features(),
            use_vertex_feature=args.use_vertex_feature,
            n_units=n_units,
            dropout=args.dropout,
            instance_normalization=args.instance_normalization,
            model_type=args.model_type)
    elif args.model == "gat":
        n_heads = [int(x) for x in args.heads.strip().split(",")]
        model = BatchGAT(
            pretrained_emb=influence_dataset.get_embedding(),
            vertex_feature=influence_dataset.get_vertex_features(),
            use_vertex_feature=args.use_vertex_feature,
            n_units=n_units,
            n_heads=n_heads,
            dropout=args.dropout,
            instance_normalization=args.instance_normalization)
    elif args.model == "pscn":
        model = BatchPSCN(
            pretrained_emb=influence_dataset.get_embedding(),
            vertex_feature=influence_dataset.get_vertex_features(),
            use_vertex_feature=args.use_vertex_feature,
            n_units=n_units,
            dropout=args.dropout,
            instance_normalization=args.instance_normalization,
            sequence_size=args.sequence_size,
            neighbor_size=args.neighbor_size)
    else:
        raise NotImplementedError

    if args.cuda:
        model.cuda()
        class_weight = class_weight.cuda()

    # Only pscn optimizes every parameter that requires gradients; the other
    # models pass just the layer stack to the optimizer, leaving the embedding
    # and vertex-feature tables out of it.
    params = [{
        'params':
        filter(lambda p: p.requires_grad, model.parameters())
        if args.model == "pscn" else model.layer_stack.parameters()
    }]

    optimizer = optim.Adagrad(params,
                              lr=args.lr,
                              weight_decay=args.weight_decay)

    t_total = time.time()
    logger.info("training...")
    for epoch in range(args.epochs):
        train(model, args, class_weight, logger, optimizer, epoch,
              train_loader, valid_loader, test_loader)
    logger.info("optimization Finished!")
    logger.info("total time elapsed: {:.4f}s".format(time.time() - t_total))

    logger.info("retrieve best threshold...")
    best_thr = evaluate(model,
                        args,
                        class_weight,
                        logger,
                        args.epochs,
                        valid_loader,
                        return_best_thr=True,
                        log_desc='valid_')
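    # best_thr is chosen on the validation set and reused unchanged on the
    # test set below.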

    # Testing
    logger.info("testing...")
    evaluate(model,
             args,
             class_weight,
             logger,
             args.epochs,
             test_loader,
             thr=best_thr,
             log_desc='test_')

    print("EXPERIMENT END------", model_name, model_type, hidden_units)