Example #1
0
def setup_logger(args):
    """Configure file logging and point tensorboard_logger at a clean directory.

    The log file is ``args.log_file`` when set; otherwise ``train.log`` if
    ``args.name`` is empty, else ``<args.name>.log``. The tensorboard
    directory is ``tensorboard_<args.name>``; any existing copy is deleted
    first so the run starts from an empty directory.

    Args:
        args: namespace providing ``log_file`` (str or None) and ``name`` (str).
    """
    tensorboard_log_dir = 'tensorboard_' + args.name
    # Remove output from a previous run with the same name. On the very first
    # run the directory does not exist yet, which must not be an error.
    try:
        shutil.rmtree(tensorboard_log_dir)
    except FileNotFoundError:
        pass

    log_file = args.log_file
    if log_file is None:
        log_file = 'train.log' if args.name == '' else args.name + '.log'

    print('Logging to: ' + log_file)

    logging.basicConfig(filename=log_file, level=logging.INFO)
    tensorboard_logger.configure(tensorboard_log_dir)
Example #2
0
    def __init__(self, args, net, G_data):
        """Keep the run configuration and network, then set up logging output.

        Hands the train/valid/test graph lists of ``G_data`` to ``self.init``,
        moves ``net`` onto the GPU when CUDA is available, and configures
        tensorboard_logger on an emptied per-dataset directory.
        """
        self.args, self.net = args, net
        self.feat_dim = G_data.feat_dim
        # Always fold 0; the fold index carried by G_data is not consulted.
        self.fold_idx = 0
        graph_splits = (G_data.train_gs, G_data.valid_gs, G_data.test_gs)
        self.init(args, *graph_splits)
        if torch.cuda.is_available():
            self.net.cuda()

        tb_dir = 'tensorboard/%s_%s_%s' % (
            "graph-u-net", args.data, args.label_type)
        # Create-then-delete: makedirs guarantees the parent exists and that
        # rmtree has a target, and rmtree discards anything left from before.
        os.makedirs(tb_dir, exist_ok=True)
        shutil.rmtree(tb_dir)
        tensorboard_logger.configure(tb_dir)
        logger.info('tensorboard logging to %s', tb_dir)
Example #3
0
                    type=int,
                    default=5,
                    help="Neighborhood size (only useful for pscn)")

args = parser.parse_args()
# Use CUDA only when args.no_cuda is not set and a GPU is actually available.
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Seed NumPy and PyTorch (plus the CUDA RNG when CUDA is used) from args.seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# One tensorboard directory per model/run name. makedirs ensures the parent
# exists and gives rmtree a target; rmtree then discards whatever a previous
# run with the same name left there before tensorboard_logger is pointed at it.
tensorboard_log_dir = 'tensorboard/%s_%s' % (args.model, args.tensorboard_log)
os.makedirs(tensorboard_log_dir, exist_ok=True)
shutil.rmtree(tensorboard_log_dir)
tensorboard_logger.configure(tensorboard_log_dir)
logger.info('tensorboard logging to %s', tensorboard_log_dir)

# adj N*n*n
# feature N*n*f
# labels N*n*c
# Load data
# vertex: vertex id in global network N*n

if args.model == "pscn":
    influence_dataset = PatchySanDataSet(args.file_dir,
                                         args.dim,
                                         args.seed,
                                         args.shuffle,
                                         args.model,
                                         sequence_size=args.sequence_size,
Example #4
0
File: train.py Project: MH-0/RPGAE
def _build_arg_parser(data_path, model_name, model_type, hidden_units,
                      instance_normalization, class_weight_balanced,
                      use_vertex_feature):
    """Return the training-options parser, with defaults taken from the caller."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--tensorboard-log',
                        type=str,
                        default='',
                        help="name of this run")
    parser.add_argument('--model',
                        type=str,
                        default=model_name,
                        help="models used")
    parser.add_argument('--model_type',
                        type=str,
                        default=model_type,
                        help="model type used")
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='Disables CUDA training.')
    parser.add_argument('--seed', type=int, default=42, help='Random seed.')
    parser.add_argument('--epochs',
                        type=int,
                        default=500,
                        help='Number of epochs to train.')
    parser.add_argument('--lr',
                        type=float,
                        default=0.1,
                        help='Initial learning rate.')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=5e-4,
                        help='Weight decay (L2 loss on parameters).')
    parser.add_argument('--dropout',
                        type=float,
                        default=0.2,
                        help='Dropout rate (1 - keep probability).')
    parser.add_argument(
        '--hidden-units',
        type=str,
        default=hidden_units,
        help="Hidden units in each hidden layer, splitted with comma")
    parser.add_argument('--heads',
                        type=str,
                        default="1,1,1",
                        help="Heads in each layer, splitted with comma")
    parser.add_argument('--batch', type=int, default=1024, help="Batch size")
    parser.add_argument('--dim',
                        type=int,
                        default=64,
                        help="Embedding dimension")
    parser.add_argument('--check-point',
                        type=int,
                        default=10,
                        help="Check point")
    parser.add_argument('--instance-normalization',
                        action='store_true',
                        default=instance_normalization,
                        help="Enable instance normalization")
    # NOTE(review): store_true with default=True means shuffling can never be
    # switched off from the command line; kept unchanged for compatibility.
    parser.add_argument('--shuffle',
                        action='store_true',
                        default=True,
                        help="Shuffle dataset")
    parser.add_argument('--file-dir',
                        type=str,
                        required=False,
                        default=data_path,
                        help="Input file directory")
    parser.add_argument('--train-ratio',
                        type=float,
                        default=75,
                        help="Training ratio (0, 100)")
    parser.add_argument('--valid-ratio',
                        type=float,
                        default=12.5,
                        help="Validation ratio (0, 100)")
    parser.add_argument('--class-weight-balanced',
                        action='store_true',
                        default=class_weight_balanced,
                        help="Adjust weights inversely proportional"
                        " to class frequencies in the input data")
    parser.add_argument('--use-vertex-feature',
                        action='store_true',
                        default=use_vertex_feature,
                        help="Whether to use vertices' structural features")
    parser.add_argument('--sequence-size',
                        type=int,
                        default=16,
                        help="Sequence size (only useful for pscn)")
    parser.add_argument('--neighbor-size',
                        type=int,
                        default=5,
                        help="Neighborhood size (only useful for pscn)")
    return parser


def _build_model(args, influence_dataset, n_units):
    """Instantiate the network selected by ``args.model``.

    All models share the embedding/vertex-feature/layer-size/dropout/
    normalization keyword arguments; each model adds its own extras.

    Raises:
        NotImplementedError: if ``args.model`` is not gcn, gnn_sum, gnn_mean,
            gat or pscn.
    """
    if args.model not in ("gcn", "gnn_sum", "gnn_mean", "gat", "pscn"):
        raise NotImplementedError(args.model)

    common = dict(
        pretrained_emb=influence_dataset.get_embedding(),
        vertex_feature=influence_dataset.get_vertex_features(),
        use_vertex_feature=args.use_vertex_feature,
        n_units=n_units,
        dropout=args.dropout,
        instance_normalization=args.instance_normalization,
    )
    if args.model == "gcn":
        return BatchGCN(**common)
    if args.model == "gnn_sum":
        return BatchGNNSUM(model_type=args.model_type, **common)
    if args.model == "gnn_mean":
        return BatchGNNMEAN(model_type=args.model_type, **common)
    if args.model == "gat":
        n_heads = [int(x) for x in args.heads.strip().split(",")]
        return BatchGAT(n_heads=n_heads, **common)
    return BatchPSCN(sequence_size=args.sequence_size,
                     neighbor_size=args.neighbor_size,
                     **common)


def train_evaluate_model(data_path, model_name, model_type, hidden_units,
                         instance_normalization, class_weight_balanced,
                         use_vertex_feature, argv=None):
    """Train the selected model, tune a threshold on validation, then test.

    The function's parameters only provide argparse *defaults*; options
    present in the parsed argument list override them.

    Args:
        data_path: default for ``--file-dir``.
        model_name: default for ``--model`` (gcn, gnn_sum, gnn_mean, gat, pscn).
        model_type: default for ``--model_type`` (used by gnn_sum/gnn_mean).
        hidden_units: default for ``--hidden-units``, comma-separated ints.
        instance_normalization: default for ``--instance-normalization``.
        class_weight_balanced: default for ``--class-weight-balanced``.
        use_vertex_feature: default for ``--use-vertex-feature``.
        argv: argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; pass ``[]`` to ignore the
            process command line when calling this function from code.

    Raises:
        NotImplementedError: if the selected model is not supported.
    """
    print("EXPERIMENT START------", model_name, model_type, hidden_units)

    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(message)s')  # include timestamp

    parser = _build_arg_parser(data_path, model_name, model_type,
                               hidden_units, instance_normalization,
                               class_weight_balanced, use_vertex_feature)
    args = parser.parse_args(argv)
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    tensorboard_log_dir = 'tensorboard/%s_%s' % (args.model,
                                                 args.tensorboard_log)
    # makedirs guarantees the parent directory exists and gives rmtree a
    # target; rmtree then discards anything a previous run with the same
    # name left behind before tensorboard_logger is pointed at it.
    os.makedirs(tensorboard_log_dir, exist_ok=True)
    shutil.rmtree(tensorboard_log_dir)
    tensorboard_logger.configure(tensorboard_log_dir)
    logger.info('tensorboard logging to %s', tensorboard_log_dir)

    if args.model == "pscn":
        influence_dataset = PatchySanDataSet(args.file_dir,
                                             args.dim,
                                             args.seed,
                                             args.shuffle,
                                             args.model,
                                             sequence_size=args.sequence_size,
                                             stride=1,
                                             neighbor_size=args.neighbor_size)
    else:
        influence_dataset = InfluenceDataSet(args.file_dir, args.dim,
                                             args.seed, args.shuffle,
                                             args.model)

    N = len(influence_dataset)
    n_classes = 2
    if args.class_weight_balanced:
        class_weight = influence_dataset.get_class_weight()
    else:
        class_weight = torch.ones(n_classes)
    logger.info("class_weight=%.2f:%.2f", class_weight[0], class_weight[1])

    feature_dim = influence_dataset.get_feature_dimension()
    n_units = ([feature_dim]
               + [int(x) for x in args.hidden_units.strip().split(",")]
               + [n_classes])
    logger.info("feature dimension=%d", feature_dim)
    logger.info("number of classes=%d", n_classes)

    # Split by index: [0, valid_start) train, [valid_start, test_start)
    # valid, [test_start, N) test; ratios are percentages. Assumes
    # ChunkSampler(count, start) samples `count` indices from `start`.
    valid_start = int(N * args.train_ratio / 100)
    test_start = int(N * (args.train_ratio + args.valid_ratio) / 100)
    train_loader = DataLoader(influence_dataset,
                              batch_size=args.batch,
                              sampler=ChunkSampler(valid_start, 0))
    valid_loader = DataLoader(influence_dataset,
                              batch_size=args.batch,
                              sampler=ChunkSampler(test_start - valid_start,
                                                   valid_start))
    test_loader = DataLoader(influence_dataset,
                             batch_size=args.batch,
                             sampler=ChunkSampler(N - test_start, test_start))

    # Model and optimizer
    model = _build_model(args, influence_dataset, n_units)
    if args.cuda:
        model.cuda()
        class_weight = class_weight.cuda()

    if args.model == "pscn":
        trainable = [p for p in model.parameters() if p.requires_grad]
    else:
        # Only ``model.layer_stack`` is handed to the optimizer for these
        # models; parameters outside it are not updated.
        trainable = model.layer_stack.parameters()
    optimizer = optim.Adagrad([{'params': trainable}],
                              lr=args.lr,
                              weight_decay=args.weight_decay)

    t_total = time.time()
    logger.info("training...")
    for epoch in range(args.epochs):
        train(model, args, class_weight, logger, optimizer, epoch,
              train_loader, valid_loader, test_loader)
    logger.info("optimization Finished!")
    logger.info("total time elapsed: %.4fs", time.time() - t_total)

    logger.info("retrieve best threshold...")
    best_thr = evaluate(model,
                        args,
                        class_weight,
                        logger,
                        args.epochs,
                        valid_loader,
                        return_best_thr=True,
                        log_desc='valid_')

    # Testing with the threshold chosen on the validation split
    logger.info("testing...")
    evaluate(model,
             args,
             class_weight,
             logger,
             args.epochs,
             test_loader,
             thr=best_thr,
             log_desc='test_')

    print("EXPERIMENT END------", model_name, model_type, hidden_units)