Example #1
File: main.py Project: lehaifeng/T-GCN
def get_model(args, dm):
    model = None
    if args.model_name == "GCN":
        model = models.GCN(adj=dm.adj, input_dim=args.seq_len, output_dim=args.hidden_dim)
    if args.model_name == "GRU":
        model = models.GRU(input_dim=dm.adj.shape[0], hidden_dim=args.hidden_dim)
    if args.model_name == "TGCN":
        model = models.TGCN(adj=dm.adj, hidden_dim=args.hidden_dim)
    return model
Example #2
File: main.py Project: zh-brilliant/T-GCN
def get_model(args, dm):
    model = None
    if args.model_name == 'GCN':
        model = models.GCN(adj=dm.adj,
                           input_dim=args.seq_len,
                           output_dim=args.hidden_dim)
    elif args.model_name == 'GRU':
        model = models.GRU(input_dim=dm.adj.shape[0],
                           hidden_dim=args.hidden_dim)
    elif args.model_name == 'TGCN':
        model = models.TGCN(adj=dm.adj,
                            hidden_dim=args.hidden_dim,
                            loss=args.loss)
    return model
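
# A hedged usage sketch for the factory above, assuming `args` comes from
# argparse and `dm` is a data module exposing `adj` (both assumptions from
# context): get_model falls through to `return None` for an unrecognized
# name, so a caller may want to guard against that.
model = get_model(args, dm)
if model is None:
    raise ValueError(f'unknown model name: {args.model_name}')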
Example #3
# create folder to save local files
os.makedirs('models', exist_ok=True)

# read data and split to train, validation
data = read_data(min_atom=5, max_atom=50)
train_data, val_data = train_test_split(data, test_size=.2, shuffle=True)

args.dim_af = train_data[0]['x'].shape[1]

# model, optimizer, criterion
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = models.GCN(args).to(device)
optimizer = optim.Adam(model.parameters(),
                       lr=args.lr,
                       weight_decay=args.weight_decay)
criterion = nn.BCELoss()
print(model)

train_dataset = MolDataSet(train_data)
train_dataloader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              collate_fn=dict_collate_fn,
                              shuffle=True)
val_dataset = MolDataSet(val_data)
val_dataloader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            collate_fn=dict_collate_fn,
                            shuffle=False)  # assumed: validation data is not shuffled
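
# The snippet above is truncated before its training loop; this is a minimal
# sketch of the loop the setup implies. args.epochs and the batch keys 'x',
# 'adj', and 'y' are assumptions based on dict_collate_fn and the earlier
# train_data[0]['x'] access:
for epoch in range(args.epochs):
    model.train()
    for batch in train_dataloader:
        x = batch['x'].to(device)
        adj = batch['adj'].to(device)
        y = batch['y'].to(device)
        optimizer.zero_grad()
        pred = model(x, adj)       # forward pass through the GCN
        loss = criterion(pred, y)  # BCELoss expects probabilities in [0, 1]
        loss.backward()
        optimizer.step()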
Example #4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', default=0, type=int)
    parser.add_argument('--model_arch',
                        default='conv4',
                        choices=['conv4', 'resnet10', 'resnet18'],
                        type=str)
    # parser.add_argument('--attention', action='store_true')
    parser.add_argument('--start_epoch', default=1, type=int)
    parser.add_argument('--num_epoch', default=90, type=int)
    parser.add_argument('--learning_rate', default=0.01, type=float)
    parser.add_argument('--scheduler_milestones', nargs='+', type=int)
    parser.add_argument('--alpha', default=1, type=float)
    parser.add_argument('--beta', default=1, type=float)
    parser.add_argument('--gamma', default=0.5, type=float)
    parser.add_argument('--model_saving_rate', default=30, type=int)
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--support_groups', default=1000, type=int)
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--evaluation_rate', default=10, type=int)
    parser.add_argument('--model_dir', type=str)
    parser.add_argument('--checkpoint', action='store_true')
    parser.add_argument('--normalize', action='store_true')
    parser.add_argument('--save_settings', action='store_true')
    parser.add_argument('--layer', default=4, type=int)
    parser.add_argument('--fusion_method', default='sum', type=str)
    parser.add_argument('--lamda', default=0, type=float)
    # parser.add_argument('--gcn_path', type=str)
    # parser.add_argument('--img_encoder_path', type=str)

    args = parser.parse_args()

    device = torch.device(f'cuda:{args.gpu}')
    model_arch = args.model_arch
    # attention = args.attention
    learning_rate = args.learning_rate
    alpha = args.alpha
    beta = args.beta
    gamma = args.gamma
    start_epoch = args.start_epoch
    num_epoch = args.num_epoch
    model_saving_rate = args.model_saving_rate
    toTrain = args.train
    toEvaluate = args.evaluate
    evaluation_rate = args.evaluation_rate
    checkpoint = args.checkpoint
    normalize = args.normalize
    scheduler_milestones = args.scheduler_milestones
    save_settings = args.save_settings
    support_groups = args.support_groups
    fusion_method = args.fusion_method
    lamda = args.lamda

    # gcn_path = args.gcn_path
    # img_encoder_path = args.img_encoder_path

    # ------------------------------- #
    # Generate folder
    # ------------------------------- #
    if checkpoint:
        model_dir = f'./training_models/{args.model_dir}'
    else:
        model_dir = f'./training_models/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'
        os.makedirs(model_dir)

    # ------------------------------- #
    # Config logger
    # ------------------------------- #
    train_logger = setup_logger('train_logger', f'{model_dir}/train_all.log')
    result_logger = setup_logger('result_logger',
                                 f'{model_dir}/result_all.log')
    if save_settings:
        # ------------------------------- #
        # Saving training parameters
        # ------------------------------- #
        result_logger.info(f'Model: {model_arch}')
        result_logger.info(f'Fusion Method: {fusion_method}; Lamda: {lamda}')
        result_logger.info(f'Attention Layer: {args.layer}')
        result_logger.info(f'Learning rate: {learning_rate}')
        result_logger.info(f'Alpha: {alpha} Beta: {beta} Gamma: {gamma}')
        # result_logger.info(f'alpha: {alpha}')
        result_logger.info(f'Normalize feature vector: {normalize}')

    # ------------------------------- #
    # Load extracted knowledge graph
    # ------------------------------- #
    knowledge_graph = Graph()
    classFile_to_superclasses, superclassID_to_wikiID = \
        knowledge_graph.class_file_to_superclasses(1, [1, 2])
    nodes = knowledge_graph.nodes
    # import ipdb; ipdb.set_trace()

    layer = 2
    layer_nums = [768, 2048, 1600]
    edges = knowledge_graph.edges

    cat_feature = 1600
    final_feature = 1024

    ####################
    # Prepare Data Set #
    ####################
    print('preparing dataset')
    base_cls, val_cls, support_cls = get_splits()

    base = MiniImageNet('base', base_cls, val_cls, support_cls,
                        classFile_to_superclasses)
    base_loader = DataLoader(base, batch_size=256, shuffle=True, num_workers=4)

    support = MiniImageNet('support',
                           base_cls,
                           val_cls,
                           support_cls,
                           classFile_to_superclasses,
                           eval=True)
    support_loader_1 = DataLoader(support,
                                  batch_sampler=SupportingSetSampler(
                                      support, 1, 5, 15, support_groups),
                                  num_workers=4)
    support_loader_5 = DataLoader(support,
                                  batch_sampler=SupportingSetSampler(
                                      support, 5, 5, 15, support_groups),
                                  num_workers=4)

    #########
    # Model #
    #########
    # sentence transformer
    sentence_transformer = SentenceTransformer(
        'paraphrase-distilroberta-base-v1')

    # image encoder
    if model_arch == 'conv4':
        img_encoder = models.Conv4Attension(len(base_cls),
                                            len(superclassID_to_wikiID))

    elif model_arch == 'resnet10':
        img_encoder = models.resnet10(len(base_cls),
                                      len(superclassID_to_wikiID))

    elif model_arch == 'resnet18':
        img_encoder = models.resnet18(len(base_cls),
                                      len(superclassID_to_wikiID))

    # img_encoder.load_state_dict(torch.load(f'{model_dir}/{img_encoder_path}'))
    # img_encoder.to(device)
    # img_encoder.eval()

    # knowledge graph encoder
    GCN = models.GCN(layer, layer_nums, edges)
    # GCN.load_state_dict(torch.load(f'{model_dir}/{gcn_path}'))
    # GCN.to(device)
    # GCN.eval()

    # total model
    model = models.FSKG(cat_feature, final_feature, img_encoder, GCN,
                        len(base_cls), lamda)
    model.to(device)

    # loss function and optimizer
    criterion = loss_fn(alpha, beta, gamma, device)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                weight_decay=1e-4,
                                nesterov=True)
    scheduler = MultiStepLR(optimizer,
                            milestones=scheduler_milestones,
                            gamma=0.1)

    if save_settings:
        result_logger.info(
            'optimizer: torch.optim.SGD(model.parameters(), '
            f'lr={learning_rate}, momentum=0.9, weight_decay=1e-4, nesterov=True)'
        )
        result_logger.info(
            f'scheduler: MultiStepLR(optimizer, milestones={scheduler_milestones}, gamma=0.1)\n'
        )
        # result_logger.info('='*40+'Results Below'+'='*40+'\n')

    if checkpoint:
        print('load model...')
        model.load_state_dict(
            torch.load(f'{model_dir}/FSKG_{start_epoch-1}.pth'))
        model.to(device)

        # for _ in range(start_epoch - 1):
        #     scheduler.step()

    # ---------------------------------------- #
    # Graph convolution to get kg embeddings
    # ---------------------------------------- #

    # encode node description
    desc_embeddings = knowledge_graph.encode_desc(sentence_transformer).to(
        device)

    # start graph convolution
    # import ipdb; ipdb.set_trace()
    # kg_embeddings = GCN(desc_embeddings)
    # kg_embeddings = kg_embeddings.to('cpu')

    classFile_to_wikiID = get_classFile_to_wikiID()
    # train_class_name_to_id = base.class_name_to_id
    train_id_to_class_name = base.id_to_class_name
    # eval_class_name_to_id = support.class_name_to_id
    eval_id_to_class_name = support.id_to_class_name

    # ------------------------------- #
    # Start to train
    # ------------------------------- #
    if toTrain:
        for epoch in range(start_epoch, start_epoch + num_epoch):
            model.train()
            train(model, img_encoder, normalize, base_loader, optimizer,
                  criterion, epoch, start_epoch + num_epoch - 1, device,
                  train_logger, nodes, desc_embeddings, train_id_to_class_name,
                  classFile_to_wikiID)
            scheduler.step()

            if epoch % model_saving_rate == 0:
                torch.save(model.state_dict(), f'{model_dir}/FSKG_{epoch}.pth')

            # ------------------------------- #
            # Evaluate current model
            # ------------------------------- #
            if toEvaluate:
                if epoch % evaluation_rate == 0:
                    evaluate(model, normalize, epoch, support_loader_1, 1, 5,
                             15, device, result_logger, nodes, desc_embeddings,
                             eval_id_to_class_name, classFile_to_wikiID)
                    evaluate(model, normalize, epoch, support_loader_5, 5, 5,
                             15, device, result_logger, nodes, desc_embeddings,
                             eval_id_to_class_name, classFile_to_wikiID)

    else:
        # pass
        if toEvaluate:
            evaluate(model, normalize, 30, support_loader_1, 1, 5, 15, device,
                     result_logger, nodes, desc_embeddings,
                     eval_id_to_class_name, classFile_to_wikiID)
            evaluate(model, normalize, 30, support_loader_5, 5, 5, 15, device,
                     result_logger, nodes, desc_embeddings,
                     eval_id_to_class_name, classFile_to_wikiID)
    result_logger.info('=' * 140)
Example #5
def train():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', default=0, type=int)
    parser.add_argument('--model_arch',
                        default='conv4',
                        choices=['conv4', 'resnet10', 'resnet18'],
                        type=str)
    parser.add_argument('--start_epoch', default=1, type=int)
    parser.add_argument('--num_epoch', default=90, type=int)
    parser.add_argument('--learning_rate', default=0.01, type=float)
    parser.add_argument('--model_saving_rate', default=30, type=int)
    parser.add_argument('--train', action='store_true')
    # parser.add_argument('--support_groups', default=10000, type=int)
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--evaluation_rate', default=10, type=int)
    parser.add_argument('--model_dir', type=str)
    parser.add_argument('--img_encoder_path', type=str)
    parser.add_argument('--checkpoint', action='store_true')
    parser.add_argument('--normalize', action='store_true')
    parser.add_argument('--save_settings', action='store_true')
    parser.add_argument('--layer', default=4, type=int)
    parser.add_argument('--classifiers_path', action='store_true')
    parser.add_argument('--optimizer', default='SGD', type=str)
    # parser.add_argument('--scheduler_milestones', nargs='+', type=int)

    args = parser.parse_args()

    device = torch.device(f'cuda:{args.gpu}')
    model_arch = args.model_arch
    learning_rate = args.learning_rate
    start_epoch = args.start_epoch
    num_epoch = args.num_epoch
    model_saving_rate = args.model_saving_rate
    # toTrain = args.train
    # toEvaluate = args.evaluate
    evaluation_rate = args.evaluation_rate
    checkpoint = args.checkpoint
    # scheduler_milestones = args.scheduler_milestones
    save_settings = args.save_settings
    model_dir = f'./training_models/{args.model_dir}'
    img_encoder_path = f'{model_dir}/{args.img_encoder_path}'
    classifiers_path = args.classifiers_path
    normalize = args.normalize

    # ------------------------------- #
    # Config logger
    # ------------------------------- #
    train_logger = setup_logger('train_logger', f'{model_dir}/gcn_train.log')
    if save_settings:
        # ------------------------------- #
        # Saving training parameters
        # ------------------------------- #
        train_logger.info(f'{model_arch} Model: {img_encoder_path}')
        train_logger.info(f'Attention Layer: {args.layer}')
        train_logger.info(f'Learning rate: {learning_rate}')
        train_logger.info(f'Optimizer: {args.optimizer}')

    # ------------------------------- #
    # Load extracted knowledge graph
    # ------------------------------- #
    knowledge_graph = Graph()
    classFile_to_superclasses, superclassID_to_wikiID = \
        knowledge_graph.class_file_to_superclasses(1, [1, 2])
    edges = knowledge_graph.edges
    nodes = knowledge_graph.nodes

    ####################
    # Prepare Data Set #
    ####################
    print('preparing dataset')
    base_cls, val_cls, support_cls = get_splits()
    base = MiniImageNet('base', base_cls, val_cls, support_cls,
                        classFile_to_superclasses)
    base_loader = DataLoader(base,
                             batch_size=256,
                             shuffle=False,
                             num_workers=4)

    # ------------------------------- #
    # Load image encoder model
    # ------------------------------- #
    # image encoder
    if model_arch == 'conv4':
        img_encoder = models.Conv4Attension(len(base_cls),
                                            len(superclassID_to_wikiID))

    elif model_arch == 'resnet10':
        img_encoder = models.resnet10(len(base_cls),
                                      len(superclassID_to_wikiID))

    elif model_arch == 'resnet18':
        img_encoder = models.resnet18(len(base_cls),
                                      len(superclassID_to_wikiID))

    img_encoder.load_state_dict(torch.load(f'{img_encoder_path}'))
    img_encoder.to(device)

    img_feature_dim = img_encoder.dim_feature

    # ------------------------------- #
    # get class classifiers
    # ------------------------------- #

    if classifiers_path:
        with open(f'{model_dir}/base_classifiers.pkl', 'rb') as f:
            classifiers = pickle.load(f)
    else:
        classifiers = get_classifier(img_encoder, img_feature_dim,
                                     len(base_cls), base_loader, 'base',
                                     normalize, model_dir, device)

    # import ipdb; ipdb.set_trace()

    # ------------------------------- #
    # Init GCN model
    # ------------------------------- #
    layer = 2
    layer_nums = [768, 2048, img_feature_dim]
    layer_nums_str = "".join([str(a) + ' ' for a in layer_nums])
    if save_settings:
        train_logger.info(f'GCN layers: {layer_nums_str}')
    GCN = models.GCN(layer, layer_nums, edges)
    # GCN = models.GCN(edges)
    GCN.to(device)
    # import ipdb; ipdb.set_trace()
    # ------------------------------- #
    # Other neccessary parameters
    # ------------------------------- #
    classFile_to_wikiID = get_classFile_to_wikiID()
    base_cls_index = [
        nodes.index(classFile_to_wikiID[base.id_to_class_name[i]])
        for i in range(len(base_cls))
    ]
    # support_cls_index = [nodes.index(classFile_to_wikiID[base.id_to_class_name[i]]) for i in range(len(support_cls))]

    sentence_transformer = SentenceTransformer(
        'paraphrase-distilroberta-base-v1')
    desc_embeddings = knowledge_graph.encode_desc(sentence_transformer)
    desc_embeddings = desc_embeddings.to(device)

    # ------------------------------- #
    # Training settings
    # ------------------------------- #
    # criterion = torch.nn.MSELoss()
    criterion = torch.nn.CosineEmbeddingLoss()
    optimizer = torch.optim.SGD(GCN.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                weight_decay=1e-4,
                                nesterov=True)
    # optimizer = torch.optim.Adam(GCN.parameters(), lr=learning_rate, weight_decay=1e-4)

    batch_time = AverageMeter()  # forward prop. + back prop. time
    losses = AverageMeter()  # loss

    GCN.train()
    start = time.time()

    classifiers = classifiers.to(device)

    loss_target = torch.ones(classifiers.shape[0]).to(device)

    for epoch in range(start_epoch, start_epoch + num_epoch):
        optimizer.zero_grad()  # clear gradients accumulated in the previous epoch
        base_embeddings = GCN(desc_embeddings)[base_cls_index]
        # import ipdb; ipdb.set_trace()
        # loss = criterion(base_embeddings, classifiers)
        loss = criterion(base_embeddings, classifiers, loss_target)
        loss.backward()
        optimizer.step()

        losses.update(loss.item())
        # print(loss.item())
        batch_time.update(time.time() - start)

        if epoch % 200 == 0:  # log every 200 epochs
            train_logger.info(
                f'[{epoch:3d}/{start_epoch+num_epoch-1}]'
                f' batch_time: {batch_time.avg:.2f} loss: {losses.avg:.3f}')
            batch_time.reset()
            losses.reset()
            start = time.time()
        if epoch % 1000 == 0:
            torch.save(GCN.state_dict(), f'{model_dir}/gcn_{epoch}.pth')

    train_logger.info("=" * 60)
Example #6
def main(**kwargs):

    # load data
    d = data.get_data(os.path.join(kwargs['pdfp'], kwargs['data_train_pkl']),
                      label=kwargs['label'],
                      sample=kwargs['sample'],
                      replicate=kwargs['replicate'],
                      incl_curvature=kwargs['incl_curvature'],
                      load_attn1=kwargs['load_attn1'],
                      load_attn2=kwargs['load_attn2'],
                      modelpkl_fname1=os.path.join(kwargs['pdfp'],
                                                   kwargs['modelpkl_fname1']),
                      modelpkl_fname2=os.path.join(kwargs['pdfp'],
                                                   kwargs['modelpkl_fname2']),
                      preloadn2v=kwargs['preloadn2v'],
                      out_channels=8,
                      heads=8,
                      negative_slope=0.2,
                      dropout=0.4)

    if not kwargs['fastmode']:
        d_val = data.get_data(
            os.path.join(kwargs['pdfp'], kwargs['data_val_pkl']),
            label=kwargs['label'],
            sample=kwargs['sample'],
            replicate=kwargs['replicate'],
            incl_curvature=kwargs['incl_curvature'],
            load_attn1=kwargs['load_attn1'],
            load_attn2=kwargs['load_attn2'],
            modelpkl_fname1=os.path.join(kwargs['pdfp'],
                                         kwargs['modelpkl_fname1']),
            modelpkl_fname2=os.path.join(kwargs['pdfp'],
                                         kwargs['modelpkl_fname2']),
            preloadn2v=kwargs['preloadn2v'],
            out_channels=8,
            heads=8,
            negative_slope=0.2,
            dropout=0.4)

    # data loader for mini-batching
    cd = data.ClusterData(d, num_parts=kwargs['NumParts'])
    cl = data.ClusterLoader(cd, batch_size=kwargs['BatchSize'], shuffle=True)

    if not kwargs['fastmode']:
        cd_val = data.ClusterData(d_val, num_parts=kwargs['NumParts'])
        cl_val = data.ClusterLoader(cd_val,
                                    batch_size=kwargs['BatchSize'],
                                    shuffle=True)

    # pick device
    if False:
        # automate?
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(kwargs['Device'])

    # import model
    models.nHiddenUnits = kwargs['nHiddenUnits']
    models.d = d
    models.nHeads = kwargs['nHeads']
    models.alpha = kwargs['alpha']
    models.dropout = kwargs['dropout']

    if 'transformer' in kwargs['model'] and kwargs['model'] != 'GAT_transformer_averaged':
        # need to define model based on max number of connections
        s_max = []
        for batch in cl:
            _, n = batch.edge_index[0].unique(return_counts=True)
            s_max.append(n.max().item())
        s_max = 5 + np.max(s_max)
        models.s_max = s_max
        print('s_max: {}'.format(s_max))

        if kwargs['model'] == 'GCN_transformer':
            model = models.GCN_transformer().to(device)
        elif kwargs['model'] == 'GCN_transformer_mlp':
            model = models.GCN_transformer_mlp().to(device)
        elif kwargs['model'] == 'GAT_transformer':
            model = models.GAT_transformer().to(device)
        elif kwargs['model'] == 'GAT_transformer_mlp':
            model = models.GAT_transformer_mlp().to(device)
        elif kwargs['model'] == 'GAT_transformer_batch':
            model = models.GAT_transformer_batch().to(device)
        elif kwargs['model'] == 'GAT_transformer_mlp_batch':
            model = models.GAT_transformer_mlp_batch().to(device)
        elif kwargs['model'] == 'GCN_transformer_mlp_batch':
            model = models.GCN_transformer_mlp_batch().to(device)

    # specific model names
    elif kwargs['model'] == 'GCN_deepset':
        model = models.GCN_deepset().to(device)
    elif kwargs['model'] == 'GCN_set2set':
        model = models.GCN_set2set().to(device)
    elif kwargs['model'] == 'GAT_deepset':
        model = models.GAT_deepset().to(device)
    elif kwargs['model'] == 'GAT_set2set':
        model = models.GAT_set2set().to(device)
    elif kwargs['model'] == 'GCN':
        model = models.GCN().to(device)
    elif kwargs['model'] == 'GAT':
        model = models.GAT().to(device)
    elif kwargs['model'] == 'GINE':
        model = models.GINE().to(device)
    elif kwargs['model'] == 'edge_cond_conv':
        model = models.edge_cond_conv().to(device)
    elif kwargs['model'] == 'GAT_transformer_averaged':
        model = models.GAT_transformer_averaged().to(device)
    else:
        print('Unknown model name. Valid names follow the pattern '
              '(GAT/GCN)(_transformer)(_mlp)(_batch); the _mlp and _batch '
              'suffixes apply only to the transformer variants.')
        exit()

    # set seeds
    random.seed(kwargs['rs'])
    np.random.seed(kwargs['rs'])
    torch.manual_seed(kwargs['rs'])
    if kwargs['Device'] == 'cuda':
        torch.cuda.manual_seed(kwargs['rs'])

    # pick optimizer
    optimizer = torch.optim.Adagrad(model.parameters(),
                                    lr=kwargs['LR'],
                                    weight_decay=kwargs['WeightDecay'])

    # set train module values
    train.model = model
    train.cl = cl
    train.optimizer = optimizer
    train.device = device
    if not kwargs['fastmode']:
        train.cl_val = cl_val
    train.model_name = kwargs['model']
    train.clip = kwargs['clip']
    train.fastmode = kwargs['fastmode']

    # train scheme
    t_total = time.time()
    loss_values = []
    bad_counter = 0
    best = kwargs['nEpochs'] + 1  # initial best loss; np.inf would be a safer sentinel
    best_epoch = 0
    for epoch in range(kwargs['nEpochs']):
        loss_values.append(train.train(epoch))

        if not kwargs['fastmode']:
            torch.save(
                model.state_dict(),
                '{}-{}{}.pkl'.format(epoch, kwargs['sample'],
                                     kwargs['replicate']))

            if loss_values[-1] < best:
                best = loss_values[-1]
                best_epoch = epoch
                bad_counter = 0
            else:
                bad_counter += 1

            if bad_counter == kwargs['patience']:
                break

            files = glob.glob('*-{}{}.pkl'.format(kwargs['sample'],
                                                  kwargs['replicate']))
            for file in files:
                epoch_nb = int(
                    file.split('-{}{}.pkl'.format(kwargs['sample'],
                                                  kwargs['replicate']))[0])
                if epoch_nb < best_epoch:
                    os.remove(file)

        elif epoch == kwargs['nEpochs'] - 1:  # last epoch; range() stops at nEpochs - 1
            torch.save(
                model.state_dict(),
                '{}-{}{}.pkl'.format(epoch, kwargs['sample'],
                                     kwargs['replicate']))

    files = glob.glob('*-{}{}.pkl'.format(kwargs['sample'],
                                          kwargs['replicate']))
    for file in files:
        epoch_nb = int(
            file.split('-{}{}.pkl'.format(kwargs['sample'],
                                          kwargs['replicate']))[0])
        if epoch_nb > best_epoch:
            os.remove(file)

    print('\nOptimization Finished! Best epoch: {}'.format(best_epoch))
    print('Training time elapsed: {}-h:m:s'.format(
        str(datetime.timedelta(seconds=time.time() - t_total))))

    if True:
        # test
        print('\nLoading epoch #{}'.format(best_epoch))

        if True:
            # send model to cpu
            if kwargs['model'] == 'GCN_transformer':
                model = models.GCN_transformer().to(torch.device('cpu'))
            elif kwargs['model'] == 'GCN_transformer_mlp':
                model = models.GCN_transformer_mlp().to(torch.device('cpu'))
            elif kwargs['model'] == 'GAT_transformer':
                model = models.GAT_transformer().to(torch.device('cpu'))
            elif kwargs['model'] == 'GAT_transformer_mlp':
                model = models.GAT_transformer_mlp().to(torch.device('cpu'))
            elif kwargs['model'] == 'GCN':
                model = models.GCN().to(torch.device('cpu'))
            elif kwargs['model'] == 'GAT':
                model = models.GAT().to(torch.device('cpu'))
            elif kwargs['model'] == 'GAT_transformer_batch':
                model = models.GAT_transformer_batch().to(device)
            elif kwargs['model'] == 'GAT_transformer_mlp_batch':
                model = models.GAT_transformer_mlp_batch().to(device)
            elif kwargs['model'] == 'GCN_transformer_mlp_batch':
                model = models.GCN_transformer_mlp_batch().to(device)
            elif kwargs['model'] == 'GCN_deepset':
                model = models.GCN_deepset().to(device)
            elif kwargs['model'] == 'GCN_set2set':
                model = models.GCN_set2set().to(device)
            elif kwargs['model'] == 'GAT_deepset':
                model = models.GAT_deepset().to(device)
            elif kwargs['model'] == 'GAT_set2set':
                model = models.GAT_set2set().to(device)
            elif kwargs['model'] == 'GINE':
                model = models.GINE().to(device)
            elif kwargs['model'] == 'edge_cond_conv':
                model = models.edge_cond_conv().to(device)
            elif kwargs['model'] == 'GAT_transformer_averaged':
                model = models.GAT_transformer_averaged().to(device)

        model.load_state_dict(
            torch.load('{}-{}{}.pkl'.format(best_epoch, kwargs['sample'],
                                            kwargs['replicate']),
                       map_location=torch.device('cpu')))

        train.test_fname = os.path.join(kwargs['pdfp'],
                                        kwargs['data_test_pkl'])
        train.label = kwargs['label']
        train.sample = kwargs['sample']
        train.replicate = kwargs['replicate']
        train.incl_curvature = kwargs['incl_curvature']
        train.load_attn1 = kwargs['load_attn1']
        train.load_attn2 = kwargs['load_attn2']
        train.modelpkl_fname1 = os.path.join(kwargs['pdfp'],
                                             kwargs['modelpkl_fname1'])
        train.modelpkl_fname2 = os.path.join(kwargs['pdfp'],
                                             kwargs['modelpkl_fname2'])
        train.preloadn2v = kwargs['preloadn2v']
        train.model = model
        train.batch_size = kwargs['BatchSize']
        train.num_parts = kwargs['NumParts']

        train.compute_test()
Example #7
    def __init__(self,
                 graph,
                 learning_rate=0.01,
                 epochs=200,
                 hidden1=16,
                 dropout=0.5,
                 weight_decay=5e-4,
                 early_stopping=10,
                 max_degree=3,
                 clf_ratio=0.1):
        """
                        learning_rate: Initial learning rate
                        epochs: Number of epochs to train
                        hidden1: Number of units in hidden layer 1
                        dropout: Dropout rate (1 - keep probability)
                        weight_decay: Weight for L2 loss on embedding matrix
                        early_stopping: Tolerance for early stopping (# of epochs)
                        max_degree: Maximum Chebyshev polynomial degree
        """
        self.graph = graph
        self.clf_ratio = clf_ratio
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.hidden1 = hidden1
        self.dropout = dropout
        self.weight_decay = weight_decay
        self.early_stopping = early_stopping
        self.max_degree = max_degree

        self.preprocess_data()
        self.build_placeholders()
        # Create model
        self.model = models.GCN(self.placeholders,
                                input_dim=self.features[2][1],
                                hidden1=self.hidden1,
                                weight_decay=self.weight_decay,
                                logging=True)
        # Initialize session
        self.sess = tf.Session()
        # Init variables
        self.sess.run(tf.global_variables_initializer())

        cost_val = []

        # Train model
        for epoch in range(self.epochs):

            t = time.time()
            # Construct feed dictionary
            feed_dict = self.construct_feed_dict(self.train_mask)
            feed_dict.update({self.placeholders['dropout']: self.dropout})

            # Training step
            outs = self.sess.run(
                [self.model.opt_op, self.model.loss, self.model.accuracy],
                feed_dict=feed_dict)

            # Validation
            cost, acc, duration = self.evaluate(self.val_mask)
            cost_val.append(cost)

            # Print results
            print("Epoch:", '%04d' % (epoch + 1), "train_loss=",
                  "{:.5f}".format(outs[1]), "train_acc=",
                  "{:.5f}".format(outs[2]), "val_loss=", "{:.5f}".format(cost),
                  "val_acc=", "{:.5f}".format(acc), "time=",
                  "{:.5f}".format(time.time() - t))

            if epoch > self.early_stopping and cost_val[-1] > np.mean(
                    cost_val[-(self.early_stopping + 1):-1]):
                print("Early stopping...")
                break
        print("Optimization Finished!")

        # Testing
        test_cost, test_acc, test_duration = self.evaluate(self.test_mask)
        print("Test set results:", "cost=", "{:.5f}".format(test_cost),
              "accuracy=", "{:.5f}".format(test_acc), "time=",
              "{:.5f}".format(test_duration))
print("2.Data Loading...")
adj, features, labels, idx_train, idx_val, idx_test = load_data(
    path=args.data_path)

adj = adj.to(device)
features = features.to(device)
labels = labels.to(device)
idx_train = idx_train.to(device)
idx_val = idx_val.to(device)
idx_test = idx_test.to(device)

print("3.Creating Model")

gc_net = models.GCN(nfeat=features.shape[1],
                    nhid=args.hidden,
                    nclass=labels.max().item() + 1,
                    dropout=args.dropout).to(device)
# if args.pretrained:
#     print('=> using pre-trained weights for PoseNet')
#     weights = torch.load(args.pretrained)
#     gc_net.load_state_dict(weights['state_dict'], strict=False)
# else:
#     gc_net.init_weights()

print("4. Setting Optimization Solver")
optimizer = torch.optim.Adam(gc_net.parameters(),
                             lr=args.lr,
                             betas=(args.momentum, args.beta),
                             weight_decay=args.weight_decay)

#exp_lr_scheduler_R = lr_scheduler.StepLR(optimizer, step_size=100, gamma=5)
Example #9
pkl = "data.pkl"

with open(pkl, "rb") as f:
    block_adj, block_feat, block_pool, y, train_len, total_len = pickle.load(f)

n_feat = block_feat.size()[1]
n_class = y.max().item() + 1
n_hid = 10
learning_rate = 0.01
weight_decay = 5e-4
dropout = 0.5
n_epochs = 100
train_idx = torch.LongTensor(range(train_len))
test_idx = torch.LongTensor(range(train_len, total_len))

model = models.GCN(nfeat=n_feat, nhid=n_hid, nclass=n_class, dropout=dropout)
optimizer = optim.Adam(model.parameters(),
                       lr=learning_rate,
                       weight_decay=weight_decay)


def train(epoch):
    t = time.time()
    model.train()
    optimizer.zero_grad()
    output = model(block_feat, block_adj, block_pool)
    loss_train = F.nll_loss(output[train_idx], y[train_idx])
    acc_train = util.accuracy(output[train_idx], y[train_idx])
    loss_train.backward()
    optimizer.step()
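
# The example ends at the training step; a minimal evaluation sketch using
# the test_idx split defined above, mirroring the metrics in train():
@torch.no_grad()
def test():
    model.eval()
    output = model(block_feat, block_adj, block_pool)
    loss_test = F.nll_loss(output[test_idx], y[test_idx])
    acc_test = util.accuracy(output[test_idx], y[test_idx])
    print('test loss: {:.4f}  test acc: {:.4f}'.format(
        loss_test.item(), float(acc_test)))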
Example #10
def main(unused_argv):
    """Main function for running experiments."""
    # Load data
    utils.tab_printer(FLAGS.flag_values_dict())
    (full_adj, feats, y_train, y_val, y_test, train_mask, val_mask, test_mask,
     train_data, val_data, test_data,
     num_data) = utils.load_ne_data_transductive_sparse(
         FLAGS.data_prefix, FLAGS.dataset, FLAGS.precalc,
         list(map(float, FLAGS.split)))

    # Partition graph and do preprocessing
    if FLAGS.bsize > 1:  # multi cluster per epoch
        _, parts = partition_utils.partition_graph(full_adj,
                                                   np.arange(num_data),
                                                   FLAGS.num_clusters)

        parts = [np.array(pt) for pt in parts]
    else:
        (parts, features_batches, support_batches, y_train_batches,
         train_mask_batches) = utils.preprocess(full_adj,
                                                feats,
                                                y_train,
                                                train_mask,
                                                np.arange(num_data),
                                                FLAGS.num_clusters,
                                                FLAGS.diag_lambda,
                                                sparse_input=True)
    # valid & test in the same time
    # validation set
    (_, val_features_batches, test_features_batches, val_support_batches,
     y_val_batches, y_test_batches,
     val_mask_batches, test_mask_batches) = utils.preprocess_val_test(
         full_adj, feats, y_val, val_mask, y_test, test_mask,
         np.arange(num_data), FLAGS.num_clusters_val, FLAGS.diag_lambda)

    # (_, val_features_batches, val_support_batches, y_val_batches,
    #  val_mask_batches) = utils.preprocess(full_adj, feats, y_val, val_mask,
    #                                       np.arange(num_data),
    #                                       FLAGS.num_clusters_val,
    #                                       FLAGS.diag_lambda)
    # # test set
    # (_, test_features_batches, test_support_batches, y_test_batches,
    #  test_mask_batches) = utils.preprocess(full_adj, feats, y_test,
    #                                        test_mask, np.arange(num_data),
    #                                        FLAGS.num_clusters_test,
    #                                        FLAGS.diag_lambda)
    idx_parts = list(range(len(parts)))

    # Define placeholders
    placeholders = {
        'support': tf.sparse_placeholder(tf.float32),
        # 'features':
        #     tf.placeholder(tf.float32),
        'features': tf.sparse_placeholder(tf.float32),
        'labels': tf.placeholder(tf.float32, shape=(None, y_train.shape[1])),
        'labels_mask': tf.placeholder(tf.int32),
        'dropout': tf.placeholder_with_default(0., shape=()),
        'fm_dropout': tf.placeholder_with_default(0., shape=()),
        'gat_dropout': tf.placeholder_with_default(0.,
                                                   shape=()),  # gat attn drop
        'num_features_nonzero': tf.placeholder(tf.int32)  # helper variable for sparse dropout
    }

    # Create model
    if FLAGS.model == 'gcn':
        model = models.GCN(placeholders,
                           input_dim=feats.shape[1],
                           logging=True,
                           multilabel=FLAGS.multilabel,
                           norm=FLAGS.layernorm,
                           precalc=FLAGS.precalc,
                           num_layers=FLAGS.num_layers,
                           residual=False,
                           sparse_inputs=True)
    elif FLAGS.model == 'gcn_nfm':
        model = models.GCN_NFM(placeholders,
                               input_dim=feats.shape[1],
                               logging=True,
                               multilabel=FLAGS.multilabel,
                               norm=FLAGS.layernorm,
                               precalc=FLAGS.precalc,
                               num_layers=FLAGS.num_layers,
                               residual=False,
                               sparse_inputs=True)
    elif FLAGS.model == 'gat_nfm':
        gat_layers = list(map(int, FLAGS.gat_layers))
        model = models.GAT_NFM(placeholders,
                               input_dim=feats.shape[1],
                               logging=True,
                               multilabel=FLAGS.multilabel,
                               norm=FLAGS.layernorm,
                               precalc=FLAGS.precalc,
                               num_layers=FLAGS.num_layers,
                               residual=False,
                               sparse_inputs=True,
                               gat_layers=gat_layers)
    else:
        raise ValueError(str(FLAGS.model))

    # Initialize session
    sess = tf.Session()

    # Init variables
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    cost_val = []
    acc_val = []
    total_training_time = 0.0
    # Train model
    for epoch in range(FLAGS.epochs):
        t = time.time()
        np.random.shuffle(idx_parts)
        if FLAGS.bsize > 1:
            (features_batches, support_batches, y_train_batches,
             train_mask_batches) = utils.preprocess_multicluster(
                 full_adj, parts, feats, y_train, train_mask,
                 FLAGS.num_clusters, FLAGS.bsize, FLAGS.diag_lambda, True)
            for pid in range(len(features_batches)):
                # Use preprocessed batch data
                features_b = features_batches[pid]
                support_b = support_batches[pid]
                y_train_b = y_train_batches[pid]
                train_mask_b = train_mask_batches[pid]
                # Construct feed dictionary
                feed_dict = utils.construct_feed_dict(features_b, support_b,
                                                      y_train_b, train_mask_b,
                                                      placeholders)
                feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                feed_dict.update(
                    {placeholders['fm_dropout']: FLAGS.fm_dropout})
                feed_dict.update(
                    {placeholders['gat_dropout']: FLAGS.gat_dropout})
                # Training step
                outs = sess.run([model.opt_op, model.loss, model.accuracy],
                                feed_dict=feed_dict)
        else:
            np.random.shuffle(idx_parts)
            for pid in idx_parts:
                # Use preprocessed batch data
                features_b = features_batches[pid]
                support_b = support_batches[pid]
                y_train_b = y_train_batches[pid]
                train_mask_b = train_mask_batches[pid]
                # Construct feed dictionary
                feed_dict = utils.construct_feed_dict(features_b, support_b,
                                                      y_train_b, train_mask_b,
                                                      placeholders)
                feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                feed_dict.update(
                    {placeholders['fm_dropout']: FLAGS.fm_dropout})
                feed_dict.update(
                    {placeholders['gat_dropout']: FLAGS.gat_dropout})
                # Training step
                outs = sess.run([model.opt_op, model.loss, model.accuracy],
                                feed_dict=feed_dict)

        total_training_time += time.time() - t
        print_str = (f'Epoch: {epoch + 1:04d} '
                     f'training time: {total_training_time:.5f} '
                     f'train_acc= {outs[2]:.5f} ')

        # Validation
        ## todo: merge validation in train procedure
        if FLAGS.validation:
            cost, acc, micro, macro = evaluate(sess, model,
                                               val_features_batches,
                                               val_support_batches,
                                               y_val_batches, val_mask_batches,
                                               val_data, placeholders)
            cost_val.append(cost)
            acc_val.append(acc)
            print_str += (f'val_acc= {acc:.5f} '
                          f'mi F1= {micro:.5f} ma F1= {macro:.5f} ')

        # tf.logging.info(print_str)
        print(print_str)

        if epoch > FLAGS.early_stopping and cost_val[-1] > np.mean(
                cost_val[-(FLAGS.early_stopping + 1):-1]):
            tf.logging.info('Early stopping...')
            break

        ### use acc early stopping, lower performance than using loss
        # if epoch > FLAGS.early_stopping and acc_val[-1] < np.mean(
        #     acc_val[-(FLAGS.early_stopping + 1):-1]):
        #   tf.logging.info('Early stopping...')
        #   break

    tf.logging.info('Optimization Finished!')

    # Save model
    saver.save(sess, FLAGS.save_name)

    # Load model (using CPU for inference)
    with tf.device('/cpu:0'):
        sess_cpu = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}))
        sess_cpu.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess_cpu, FLAGS.save_name)
        # Testing
        test_cost, test_acc, micro, macro = evaluate(
            sess_cpu, model, test_features_batches, val_support_batches,
            y_test_batches, test_mask_batches, test_data, placeholders)
        print_str = (f'Test set results: cost= {test_cost:.5f} '
                     f'accuracy= {test_acc:.5f} '
                     f'mi F1= {micro:.5f} ma F1= {macro:.5f}')
        tf.logging.info(print_str)
Example #11
    def __init__(self, args):
        self.args = args
        self.mode = args.mode
        self.epochs = args.epochs
        self.dataset = args.dataset
        self.data_path = args.data_path
        self.train_crop_size = args.train_crop_size
        self.eval_crop_size = args.eval_crop_size
        self.stride = args.stride
        self.batch_size = args.train_batch_size
        self.train_data = AerialDataset(crop_size=self.train_crop_size,
                                        dataset=self.dataset,
                                        data_path=self.data_path,
                                        mode='train')
        self.train_loader = DataLoader(self.train_data,
                                       batch_size=self.batch_size,
                                       shuffle=True,
                                       num_workers=2)
        self.eval_data = AerialDataset(dataset=self.dataset,
                                       data_path=self.data_path,
                                       mode='val')
        self.eval_loader = DataLoader(self.eval_data,
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=2)

        if self.dataset == 'Potsdam':
            self.num_of_class = 6
            self.epoch_repeat = get_test_times(6000, 6000,
                                               self.train_crop_size,
                                               self.train_crop_size)
        elif self.dataset == 'UDD5':
            self.num_of_class = 5
            self.epoch_repeat = get_test_times(4000, 3000,
                                               self.train_crop_size,
                                               self.train_crop_size)
        elif self.dataset == 'UDD6':
            self.num_of_class = 6
            self.epoch_repeat = get_test_times(4000, 3000,
                                               self.train_crop_size,
                                               self.train_crop_size)
        else:
            raise NotImplementedError

        if args.model == 'FCN':
            self.model = models.FCN8(num_classes=self.num_of_class)
        elif args.model == 'DeepLabV3+':
            self.model = models.DeepLab(num_classes=self.num_of_class,
                                        backbone='resnet')
        elif args.model == 'GCN':
            self.model = models.GCN(num_classes=self.num_of_class)
        elif args.model == 'UNet':
            self.model = models.UNet(num_classes=self.num_of_class)
        elif args.model == 'ENet':
            self.model = models.ENet(num_classes=self.num_of_class)
        elif args.model == 'D-LinkNet':
            self.model = models.DinkNet34(num_classes=self.num_of_class)
        else:
            raise NotImplementedError

        if args.loss == 'CE':
            self.criterion = CrossEntropyLoss2d()
        elif args.loss == 'LS':
            self.criterion = LovaszSoftmax()
        elif args.loss == 'F':
            self.criterion = FocalLoss()
        elif args.loss == 'CE+D':
            self.criterion = CE_DiceLoss()
        else:
            raise NotImplementedError

        self.schedule_mode = args.schedule_mode
        self.optimizer = opt.AdamW(self.model.parameters(), lr=args.lr)
        if self.schedule_mode == 'step':
            self.scheduler = opt.lr_scheduler.StepLR(self.optimizer,
                                                     step_size=30,
                                                     gamma=0.1)
        elif self.schedule_mode == 'miou' or self.schedule_mode == 'acc':
            self.scheduler = opt.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                                mode='max',
                                                                patience=10,
                                                                factor=0.1)
        elif self.schedule_mode == 'poly':
            iters_per_epoch = len(self.train_loader)
            self.scheduler = Poly(self.optimizer,
                                  num_epochs=args.epochs,
                                  iters_per_epoch=iters_per_epoch)
        else:
            raise NotImplementedError

        self.evaluator = Evaluator(self.num_of_class)

        self.model = nn.DataParallel(self.model)

        self.cuda = args.cuda
        if self.cuda:
            self.model = self.model.cuda()

        self.resume = args.resume
        self.finetune = args.finetune
        assert self.resume is None or self.finetune is None, \
            'resume and finetune are mutually exclusive'

        if self.resume is not None:
            print("Loading existing model...")
            if self.cuda:
                checkpoint = torch.load(args.resume)
            else:
                checkpoint = torch.load(args.resume, map_location='cpu')
            self.model.load_state_dict(checkpoint['parameters'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.start_epoch = checkpoint['epoch'] + 1
            # start from the next epoch
        elif self.finetune is not None:
            print("Loading existing model...")
            if self.cuda:
                checkpoint = torch.load(args.finetune)
            else:
                checkpoint = torch.load(args.finetune, map_location='cpu')
            self.model.load_state_dict(checkpoint['parameters'])
            self.start_epoch = checkpoint['epoch'] + 1
        else:
            self.start_epoch = 1
        if self.mode == 'train':
            self.writer = SummaryWriter(comment='-' + self.dataset + '_' +
                                        self.model.__class__.__name__ + '_' +
                                        args.loss)
        self.init_eval = args.init_eval
Example #12
def main():
    config = utils.parse_args()

    if config['cuda'] and torch.cuda.is_available():
        device = 'cuda:0'
    else:
        device = 'cpu'

    dataset_args = (config['task'], config['dataset'], config['dataset_path'],
                    config['num_layers'], config['self_loop'],
                    config['normalize_adj'])
    dataset = utils.get_dataset(dataset_args)

    input_dim, output_dim = dataset.get_dims()
    adj, features, labels, idx_train, idx_val, idx_test = dataset.get_data()
    x = features
    y_train = labels[idx_train]
    y_val = labels[idx_val]
    y_test = labels[idx_test]

    model = models.GCN(input_dim, config['hidden_dims'], output_dim,
                       config['dropout'])
    model.to(device)

    if not config['load']:
        criterion = utils.get_criterion(config['task'])
        optimizer = optim.Adam(model.parameters(),
                               lr=config['lr'],
                               weight_decay=config['weight_decay'])
        epochs = config['epochs']
        model.train()
        print('--------------------------------')
        print('Training.')
        for epoch in range(epochs):
            optimizer.zero_grad()
            scores = model(x, adj)[idx_train]
            loss = criterion(scores, y_train)
            loss.backward()
            optimizer.step()
            predictions = torch.max(scores, dim=1)[1]
            num_correct = torch.sum(predictions == y_train).item()
            accuracy = num_correct / len(y_train)
            print('    Training epoch: {}, loss: {:.3f}, accuracy: {:.2f}'.
                  format(epoch + 1, loss.item(), accuracy))
        print('Finished training.')
        print('--------------------------------')

        if config['save']:
            print('--------------------------------')
            directory = os.path.join(os.path.dirname(os.getcwd()),
                                     'trained_models')
            if not os.path.exists(directory):
                os.makedirs(directory)
            fname = utils.get_fname(config)
            path = os.path.join(directory, fname)
            print('Saving model at {}'.format(path))
            torch.save(model.state_dict(), path)
            print('Finished saving model.')
            print('--------------------------------')

    if config['load']:
        directory = os.path.join(os.path.dirname(os.getcwd()),
                                 'trained_models')
        fname = utils.get_fname(config)
        path = os.path.join(directory, fname)
        model.load_state_dict(torch.load(path))
    model.eval()
    print('--------------------------------')
    print('Testing.')
    scores = model(x, adj)[idx_test]
    predictions = torch.max(scores, dim=1)[1]
    num_correct = torch.sum(predictions == y_test).item()
    accuracy = num_correct / len(y_test)
    print('    Test accuracy: {}'.format(accuracy))
    print('Finished testing.')
    print('--------------------------------')