Example #1
    def train(self, sess, config):
        """Trains node classification model or joint node edge model."""
        self._log('Training {} model...'.format(self.model_name))
        self._log('Training parameters : \n ' + format_params(config))
        epochs = config.epochs
        lr = config.lr
        patience = config.patience
        # best_step = self.global_step
        # step for patience
        curr_step = 0
        # best metrics to select model
        best_val_metrics = self._init_best_metrics()
        best_test_metrics = self._init_best_metrics()
        # train the model
        for epoch in range(epochs):
            self.global_step += 1
            sess.run(self.train_ops, feed_dict=self._make_feed_dict('train'))
            train_metrics = self._eval_model(sess, 'train')
            val_metrics = self._eval_model(sess, 'val')
            self._log('Epoch {} : lr = {:.4f} | '.format(epoch, lr) +
                      format_metrics(train_metrics, 'train') +
                      format_metrics(val_metrics, 'val'))
            # write summaries
            train_summary = sess.run(self.train_summary,
                                     self._make_feed_dict('train'))
            val_summary = sess.run(self.val_summary,
                                   self._make_feed_dict('val'))
            self.summary_writer.add_summary(train_summary,
                                            global_step=self.global_step)
            self.summary_writer.add_summary(val_summary,
                                            global_step=self.global_step)
            # save model checkpoint if val acc increased and val loss decreased
            comp = check_improve(best_val_metrics, val_metrics,
                                 self.target_metrics)
            if np.any(comp):
                if np.all(comp):
                    # best_step = self.global_step
                    # save_path = os.path.join(save_dir, 'model')
                    # self.saver.save(sess, save_path, global_step=self.global_step)
                    best_test_metrics = self._eval_model(sess, 'test')
                best_val_metrics = val_metrics
                curr_step = 0
            else:
                curr_step += 1
                if curr_step == patience:
                    self._log('Early stopping')
                    break

        self._log('\n' + '*' * 40 + ' Best model metrics ' + '*' * 40)
        # load best model to evaluate on test set
        # save_path = os.path.join(save_dir, 'model-{}'.format(best_step))
        # self.restore_checkpoint(sess, save_path)
        self._log(format_metrics(best_val_metrics, 'val'))
        self._log(format_metrics(best_test_metrics, 'test'))
        self._log('\n' + '*' * 40 + ' Training done ' + '*' * 40)
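
Example #1 relies on helpers (format_params, format_metrics, check_improve) that are not shown. Below is a minimal sketch of what check_improve might look like, assuming target_metrics lists the metric names to compare and that loss-like metrics should decrease; the actual helper may differ.

import numpy as np

# Hypothetical sketch only: check_improve is not defined in the snippet above.
# Assume it returns one boolean per target metric, True when the new validation
# value improves on the best seen so far (lower for losses, higher otherwise).
# Example #1 treats np.any(comp) as "some metric improved" (reset patience) and
# np.all(comp) as "every metric improved" (also re-evaluate on the test split).
def check_improve(best_metrics, new_metrics, target_metrics):
    comp = []
    for name in target_metrics:
        if 'loss' in name:
            comp.append(new_metrics[name] < best_metrics[name])
        else:
            comp.append(new_metrics[name] > best_metrics[name])
    return np.array(comp)

# Usage
best = {'acc': 0.80, 'loss': 0.52}
new = {'acc': 0.83, 'loss': 0.49}
print(check_improve(best, new, ['acc', 'loss']))  # [ True  True ]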
Example #2
    def restore_checkpoint(self, sess, model_checkpoint=None):
        """Loads model checkpoint if found and computes evaluation metrics."""
        if model_checkpoint is None or not tf.train.checkpoint_exists(
                model_checkpoint):
            self.init_model_weights(sess)
        else:
            self._log(
                'Loading existing model saved at {}'.format(model_checkpoint))
            self.saver.restore(sess, model_checkpoint)
            self.global_step = int(model_checkpoint.split('-')[-1])
            val_metrics = self._eval_model(sess, 'val')
            test_metrics = self._eval_model(sess, 'test')
            self._log(format_metrics(val_metrics, 'val'))
            self._log(format_metrics(test_metrics, 'test'))
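
The step counter in Example #2 is recovered from the checkpoint path itself, which assumes TF1-style checkpoint names ending in '-<global_step>'. A tiny illustration with a hypothetical path:

# Hypothetical path, for illustration only; mirrors the parsing above.
model_checkpoint = '/tmp/checkpoints/model-1200'
global_step = int(model_checkpoint.split('-')[-1])
print(global_step)  # 1200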
Example #3
    def eval(self):
        model = self.model
        data = self.data
        args = self.args
        save_dir = self.args.save_dir
        args.np_seed = None

        if not self.best_test_metrics and args.test_prop > 0:
            model.eval()
            self.best_emb = model.encode(data['features'], self.adj_train_enc)
            self.best_test_metrics = model.compute_metrics(self.best_emb, data, 'test')

        ## CLUSTERING EVAL START
        if args.node_cluster == 1:
            metrics_clustering, pred_label = model.eval_cluster(self.best_emb, data, 'all')
            self.best_test_metrics.update(metrics_clustering)
            self.best_val_metrics = self.best_test_metrics
        else:
            metrics_clustering, pred_label = model.eval_cluster(self.best_emb, data, 'all')
            self.best_test_metrics.update(metrics_clustering)
        ## CLUSTERING EVAL END
        logging.info(" ".join(["Val set results:", format_metrics(self.best_val_metrics, 'val')]))
        logging.info(" ".join(["Test set results:", format_metrics(self.best_test_metrics, 'test')]))
        if args.save:
            np.save(os.path.join(save_dir, 'embeddings.npy'), self.best_emb.cpu().detach().numpy())
            if hasattr(model.encoder, 'att_adj'):
                filename = os.path.join(save_dir, args.dataset + '_att_adj.p')
                pickle.dump(model.encoder.att_adj.cpu().to_dense(), open(filename, 'wb'))
                print('Dumped attention adj: ' + filename)

            json.dump(vars(args), open(os.path.join(save_dir, 'config.json'), 'w'))
            torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth'))
            logging.info(f"Saved model in {save_dir}")

        return self.best_emb, pred_label
Example #4
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if int(args.double_precision):
        torch.set_default_dtype(torch.float64)
    if int(args.cuda) >= 0:
        torch.cuda.manual_seed(args.seed)
    args.device = 'cuda:' + str(args.cuda) if int(args.cuda) >= 0 else 'cpu'
    args.patience = args.epochs if not args.patience else int(args.patience)
    logging.getLogger().setLevel(logging.INFO)
    if args.save:
        if not args.save_dir:
            dt = datetime.datetime.now()
            date = f"{dt.year}_{dt.month}_{dt.day}"
            models_dir = os.path.join(os.environ['LOG_DIR'], args.task, date)
            save_dir = get_dir_name(models_dir)
        else:
            save_dir = args.save_dir
        logging.basicConfig(level=logging.INFO,
                            handlers=[
                                logging.FileHandler(
                                    os.path.join(save_dir, 'log.txt')),
                                logging.StreamHandler()
                            ])

    logging.info(f'Using: {args.device}')
    logging.info("Using seed {}.".format(args.seed))

    # Load data
    data = load_data(args, os.path.join(os.environ['DATAPATH'], args.dataset))
    args.n_nodes, args.feat_dim = data['features'].shape

    if args.task == 'nc':
        Model = NCModel
        args.n_classes = int(data['labels'].max() + 1)
        logging.info(f'Num classes: {args.n_classes}')
    else:
        args.nb_false_edges = len(data['train_edges_false'])
        args.nb_edges = len(data['train_edges'])
        if args.task == 'lp':
            Model = LPModel

    if not args.lr_reduce_freq:
        args.lr_reduce_freq = args.epochs

    # Model and optimizer
    model = Model(args)
    logging.info(str(model))
    optimizer = getattr(optimizers,
                        args.optimizer)(params=model.parameters(),
                                        lr=args.lr,
                                        weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=int(
                                                       args.lr_reduce_freq),
                                                   gamma=float(args.gamma))
    tot_params = sum([np.prod(p.size()) for p in model.parameters()])
    logging.info(f"Total number of parameters: {tot_params}")
    if args.cuda is not None and int(args.cuda) >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
        model = model.to(args.device)
        for x, val in data.items():
            if torch.is_tensor(data[x]):
                data[x] = data[x].to(args.device)

    # Train model
    t_total = time.time()
    counter = 0
    best_val_metrics = model.init_metric_dict()
    best_test_metrics = None
    best_emb = None
    for epoch in range(args.epochs):
        t = time.time()
        model.train()
        optimizer.zero_grad()
        embeddings = model.encode(data['features'], data['adj_train_norm'])
        train_metrics = model.compute_metrics(embeddings, data, 'train', args)
        train_metrics['loss'].backward()

        if args.grad_clip is not None:
            max_norm = float(args.grad_clip)
            all_params = list(model.parameters())
            for param in all_params:
                torch.nn.utils.clip_grad_norm_(param, max_norm)
        optimizer.step()
        lr_scheduler.step()
        if (epoch + 1) % args.log_freq == 0:
            logging.info(" ".join([
                'Epoch: {:04d}'.format(epoch + 1),
                'lr: {}'.format(lr_scheduler.get_lr()[0]),
                format_metrics(train_metrics, 'train'),
                'time: {:.4f}s'.format(time.time() - t)
            ]))
        if (epoch + 1) % args.eval_freq == 0:
            model.eval()
            embeddings = model.encode(data['features'], data['adj_train_norm'])
            val_metrics = model.compute_metrics(embeddings, data, 'val', args)
            if (epoch + 1) % args.log_freq == 0:
                logging.info(" ".join([
                    'Epoch: {:04d}'.format(epoch + 1),
                    format_metrics(val_metrics, 'val')
                ]))
            if model.has_improved(best_val_metrics, val_metrics):
                best_test_metrics = model.compute_metrics(
                    embeddings, data, 'test', args)
                if isinstance(embeddings, tuple):
                    best_emb = torch.cat(
                        (pmath.logmap0(embeddings[0], c=1.0), embeddings[1]),
                        dim=1).cpu()
                else:
                    best_emb = embeddings.cpu()
                if args.save:
                    np.save(os.path.join(save_dir, 'embeddings.npy'),
                            best_emb.detach().numpy())

                best_val_metrics = val_metrics
                counter = 0
            else:
                counter += 1
                if counter == args.patience and epoch > args.min_epochs:
                    logging.info("Early stopping")
                    break

    logging.info("Optimization Finished!")
    logging.info("Total time elapsed: {:.4f}s".format(time.time() - t_total))
    if not best_test_metrics:
        model.eval()
        best_emb = model.encode(data['features'], data['adj_train_norm'])
        best_test_metrics = model.compute_metrics(best_emb, data, 'test', args)
    logging.info(" ".join(
        ["Val set results:",
         format_metrics(best_val_metrics, 'val')]))
    logging.info(" ".join(
        ["Test set results:",
         format_metrics(best_test_metrics, 'test')]))

    if args.save:
        if isinstance(best_emb, tuple):
            best_emb = torch.cat(
                (pmath.logmap0(best_emb[0], c=1.0), best_emb[1]), dim=1).cpu()
        else:
            best_emb = best_emb.cpu()
        np.save(os.path.join(save_dir, 'embeddings.npy'),
                best_emb.detach().numpy())
        if hasattr(model.encoder, 'att_adj'):
            filename = os.path.join(save_dir, args.dataset + '_att_adj.p')
            pickle.dump(model.encoder.att_adj.cpu().to_dense(),
                        open(filename, 'wb'))
            print('Dumped attention adj: ' + filename)

        json.dump(vars(args), open(os.path.join(save_dir, 'config.json'), 'w'))
        torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth'))
        logging.info(f"Saved model in {save_dir}")
Example #5
    if (epoch + 1) % args.eval_freq == 0:
        model.eval()
        embeddings = model.encode(data['features'], data['adj_train_norm'])
        val_metrics = model.compute_metrics(embeddings, data, 'val')
        if model.has_improved(best_val_metrics, val_metrics):
            best_test_metrics = model.compute_metrics(embeddings, data, 'test')
            best_pred_metrics = model.compute_metrics(embeddings, data, 'pred')
            best_emb = embeddings.cpu()
            best_val_metrics = val_metrics
            counter = 0
        else:
            counter += 1
            if counter == args.patience and epoch > args.min_epochs:
                print("Early stopping")
                break

print("Optimization Finished!")
print("Total time elapsed: {:.4f}s".format(time.time() - t_total))
if not best_test_metrics:
    model.eval()
    best_emb = model.encode(data['features'], data['adj_train_norm'])
    best_test_metrics = model.compute_metrics(best_emb, data, 'test')
    best_pred_metrics = model.compute_metrics(best_emb, data, 'pred')
print(" ".join(["Val set results:", format_metrics(best_val_metrics, 'val')]))
print(" ".join(
    ["Test set results:",
     format_metrics(best_test_metrics, 'test')]))
print(" ".join(
    ["Pred set results:",
     format_metrics(best_pred_metrics, 'pred')]))
Example #6
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if int(args.double_precision):
        torch.set_default_dtype(torch.float64)
    if int(args.cuda) >= 0:
        torch.cuda.manual_seed(args.seed)
    args.device = 'cuda:' + str(args.cuda) if int(args.cuda) >= 0 else 'cpu'
    args.patience = args.epochs if not args.patience else int(args.patience)
    logging.getLogger().setLevel(logging.INFO)
    if args.save:
        if not args.save_dir:
            dt = datetime.datetime.now()
            date = f"{dt.year}_{dt.month}_{dt.day}"
            models_dir = os.path.join(os.environ['LOG_DIR'], args.task, date)
            save_dir = get_dir_name(models_dir)
        else:
            save_dir = args.save_dir
        logging.basicConfig(level=logging.INFO,
                            handlers=[
                                logging.FileHandler(
                                    os.path.join(save_dir, 'log.txt')),
                                logging.StreamHandler()
                            ])

    logging.info(f'Using: {args.device}')
    logging.info("Using seed {}.".format(args.seed))

    reserve_mark = 0

    if args.task == 'nc':
        reserve_mark = 0
    else:
        args.task = 'nc'
        reserve_mark = 1
    # Load data
    data = load_data(args, os.path.join('data/', args.dataset))
    args.n_nodes, args.feat_dim = data['features'].shape
    if args.task == 'nc':
        Model = ADVNCModel
        args.n_classes = int(data['labels'].max() + 1)
        logging.info(f'Num classes: {args.n_classes}')
    else:
        args.nb_false_edges = len(data['train_edges_false'])
        args.nb_edges = len(data['train_edges'])
        if args.task == 'lp':
            Model = ADVLPModel
        else:
            Model = RECModel
            # No validation for reconstruction task
            args.eval_freq = args.epochs + 1

    # Transfer loading: the reserve_mark switches load the dataset once under
    # the 'nc' splits (data) and once under the 'lp' splits (data1), then
    # restore the original task.
    if reserve_mark == 1:
        args.task = 'lp'
        # reset reserve mark
        reserve_mark = 0

    if args.task == 'lp':
        reserve_mark = 0
    else:
        args.task = 'lp'
        reserve_mark = 1

    data1 = load_data(args, os.path.join('data/', args.dataset))
    args.n_nodes, args.feat_dim = data1['features'].shape
    if args.task == 'nc':
        Model = ADVNCModel
        args.n_classes = int(data1['labels'].max() + 1)
        logging.info(f'Num classes: {args.n_classes}')
    else:
        print('*****')
        args.nb_false_edges = len(data1['train_edges_false'])
        args.nb_edges = len(data1['train_edges'])
        if args.task == 'lp':
            Model = ADVLPModel
        else:
            Model = RECModel
            # No validation for reconstruction task
            args.eval_freq = args.epochs + 1

    if reserve_mark == 1:
        args.task = 'nc'

    if args.task == 'nc':
        Model = ADVNCModel
    else:
        Model = ADVLPModel

    if not args.lr_reduce_freq:
        args.lr_reduce_freq = args.epochs

    # Model and optimizer
    model = Model(args)
    logging.info(str(model))
    optimizer = getattr(optimizers,
                        args.optimizer)(params=model.parameters(),
                                        lr=args.lr,
                                        weight_decay=args.weight_decay)
    optimizer_en = getattr(optimizers,
                           args.optimizer)(params=model.encoder.parameters(),
                                           lr=args.lr,
                                           weight_decay=args.weight_decay)
    optimizer_de = getattr(optimizers,
                           args.optimizer)(params=model.decoder.parameters(),
                                           lr=args.lr,
                                           weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=int(
                                                       args.lr_reduce_freq),
                                                   gamma=float(args.gamma))
    lr_scheduler_en = torch.optim.lr_scheduler.StepLR(optimizer_en,
                                                      step_size=int(
                                                          args.lr_reduce_freq),
                                                      gamma=float(args.gamma))
    lr_scheduler_de = torch.optim.lr_scheduler.StepLR(optimizer_de,
                                                      step_size=int(
                                                          args.lr_reduce_freq),
                                                      gamma=float(args.gamma))
    tot_params = sum([np.prod(p.size()) for p in model.parameters()])
    logging.info(f"Total number of parameters: {tot_params}")
    if args.cuda is not None and int(args.cuda) >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
        model = model.to(args.device)
        for x, val in data.items():
            if torch.is_tensor(data[x]):
                data[x] = data[x].to(args.device)
        for x, val in data1.items():
            if torch.is_tensor(data1[x]):
                data1[x] = data1[x].to(args.device)
    # Train model
    t_total = time.time()
    counter = 0
    best_val_metrics = model.init_metric_dict()
    best_test_metrics = None
    best_emb = None
    for epoch in range(args.epochs):
        t = time.time()
        model.train()
        model.mark = 1
        # if epoch%3==0:

        # model.save_net()
        # model.save_emb()
        # lr_scheduler.step()

        # if epoch%3==1:
        #     if epoch > 100:
        optimizer.zero_grad()
        optimizer_en.zero_grad()
        # model.load_emb()
        embeddings1 = model.encode(data1['features'], data1['adj_train_norm'])
        train_metrics1 = model.compute_metrics1(embeddings1, data1, 'train')
        loss1 = 10 * train_metrics1['loss']
        loss1.backward()
        if args.grad_clip is not None:
            max_norm = float(args.grad_clip)
            all_params = list(model.parameters())
            for param in all_params:
                torch.nn.utils.clip_grad_norm_(param, max_norm)
        optimizer.step()
        # model.load_net()
        #     # lr_scheduler.step()
        #     if (args.model == 'GCN' or args.model == 'GAT' and epoch > 100) or (args.model == 'HGCN' and epoch > 1000):
        optimizer_en.zero_grad()
        embeddings = model.encode(data['features'], data['adj_train_norm'])
        train_metrics = model.compute_metrics(embeddings, data, 'train')
        loss = -(train_metrics['loss'])  # - train_metrics['loss_shuffle'])
        loss.backward()
        optimizer_en.step()
        #
        # # if epoch%3==2:
        optimizer_de.zero_grad()
        embeddings2 = model.encode(data['features'],
                                   data['adj_train_norm']).detach_()
        train_metrics2 = model.compute_metrics(embeddings2, data, 'train')
        loss2 = (train_metrics2['loss'])  # - train_metrics2['loss_shuffle'])
        loss2.backward()
        optimizer_de.step()
        lr_scheduler.step()
        lr_scheduler_en.step()
        lr_scheduler_de.step()
        # if epoch<100:
        #     train_metrics2 = train_metrics
        if (epoch + 1) % args.log_freq == 0:
            logging.info(" ".join([
                'Epoch: {:04d}'.format(epoch + 1),
                'lr: {}'.format(lr_scheduler.get_lr()[0]),
                format_metrics(train_metrics1, 'train'),
                format_metrics(train_metrics2, 'train'),
                'time: {:.4f}s'.format(time.time() - t)
            ]))
            # Only log best metrics once an improvement has populated them.
            if best_test_metrics is not None:
                logging.info(" ".join([
                    "Val set results:",
                    format_metrics(best_val_metrics, 'val')
                ]))
                logging.info(" ".join([
                    "Val set results:",
                    format_metrics(best_val_metrics1, 'val')
                ]))
                logging.info(" ".join([
                    "Test set results:",
                    format_metrics(best_test_metrics, 'test')
                ]))
                logging.info(" ".join([
                    "Test set results:",
                    format_metrics(best_test_metrics1, 'test')
                ]))
        if (epoch + 1) % args.eval_freq == 0:
            model.eval()
            embeddings = model.encode(data['features'], data['adj_train_norm'])
            val_metrics = model.compute_metrics(embeddings, data, 'val')
            embeddings1 = model.encode(data1['features'],
                                       data1['adj_train_norm'])
            val_metrics1 = model.compute_metrics1(embeddings1, data1, 'val')
            if (epoch + 1) % args.log_freq == 0:
                logging.info(" ".join([
                    'Epoch: {:04d}'.format(epoch + 1),
                    format_metrics(val_metrics, 'val'),
                    format_metrics(val_metrics1, 'val')
                ]))

            embeddings = model.encode(data['features'], data['adj_train_norm'])
            test_metrics = model.compute_metrics(embeddings, data, 'test')
            embeddings1 = model.encode(data1['features'],
                                       data1['adj_train_norm'])
            test_metrics1 = model.compute_metrics1(embeddings1, data1, 'test')
            if (epoch + 1) % args.log_freq == 0:
                logging.info(" ".join([
                    'Epoch: {:04d}'.format(epoch + 1),
                    format_metrics(test_metrics, 'test'),
                    format_metrics(test_metrics1, 'test')
                ]))
            if model.has_improved(best_val_metrics, val_metrics1):
                best_test_metrics = model.compute_metrics1(
                    embeddings1, data1, 'test')
                best_test_metrics1 = model.compute_metrics(
                    embeddings, data, 'test')
                best_emb = embeddings1.cpu()
                if args.save:
                    np.save(os.path.join(save_dir, 'embeddings.npy'),
                            best_emb.detach().numpy())
                best_val_metrics = val_metrics1
                best_val_metrics1 = val_metrics1
                counter = 0
            # else:
            #     counter += 1
            #     if counter == args.patience and epoch > args.min_epochs:
            #         logging.info("Early stopping")
            #         break

    logging.info("Optimization Finished!")
    logging.info("Total time elapsed: {:.4f}s".format(time.time() - t_total))
    if not best_test_metrics:
        model.eval()
        best_emb = model.encode(data['features'], data['adj_train_norm'])
        best_test_metrics = model.compute_metrics(best_emb, data, 'test')
    logging.info(" ".join(
        ["Val set results:",
         format_metrics(best_val_metrics, 'val')]))
    logging.info(" ".join(
        ["Test set results:",
         format_metrics(best_test_metrics, 'test')]))
    if args.save:
        np.save(os.path.join(save_dir, 'embeddings.npy'),
                best_emb.cpu().detach().numpy())
        if hasattr(model.encoder, 'att_adj'):
            filename = os.path.join(save_dir, args.dataset + '_att_adj.p')
            pickle.dump(model.encoder.att_adj.cpu().to_dense(),
                        open(filename, 'wb'))
            print('Dumped attention adj: ' + filename)

        json.dump(vars(args), open(os.path.join(save_dir, 'config.json'), 'w'))
        torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth'))
        logging.info(f"Saved model in {save_dir}")
Example #7
def train(args):
    if int(args.double_precision):
        torch.set_default_dtype(torch.float64)
    args.device = 'cuda:' + str(args.cuda) if int(args.cuda) >= 0 else 'cpu'
    args.patience = args.epochs if not args.patience else int(args.patience)
    logging.getLogger().setLevel(logging.INFO)

    logging.info(f'Using: {args.device}')
    logging.info("Using seed {}.".format(args.seed))

    data = load_data(args, os.path.join('./data', args.dataset))
    args.n_nodes, args.feat_dim = data['features'].shape
    if args.task == 'nc':
        Model = NCModel
        args.n_classes = int(data['labels'].max() + 1)
        logging.info(f'Num classes: {args.n_classes}')
    else:
        args.nb_false_edges = len(data['train_edges_false'])
        args.nb_edges = len(data['train_edges'])
        if args.task == 'lp':
            Model = LPModel
        else:
            Model = RECModel
            args.eval_freq = args.epochs + 1

    # Model and optimizer
    model = Model(args)

    optimizer, lr_scheduler, stiefel_optimizer, stiefel_lr_scheduler = \
                        set_up_optimizer_scheduler(False, args, model, args.lr, args.lr_stie)
    tot_params = sum([np.prod(p.size()) for p in model.parameters()])

    if args.cuda is not None and int(args.cuda) >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
        model = model.to(args.device)
        for x, val in data.items():
            if torch.is_tensor(data[x]):
                data[x] = data[x].to(args.device)

    # Train model
    t_total = time.time()
    counter = 0
    best_val_metrics = model.init_metric_dict()
    best_test_metrics = None
    best_emb = None
    for epoch in range(args.epochs):
        t = time.time()
        model.train()
        optimizer.zero_grad()
        stiefel_optimizer.zero_grad()
        embeddings = model.encode(data['features'], data['hgnn_adj'],
                                  data['hgnn_weight'])
        train_metrics = model.compute_metrics(embeddings, data, 'train')
        train_metrics['loss'].backward()
        if args.grad_clip is not None:
            max_norm = float(args.grad_clip)
            all_params = list(model.parameters())
            for param in all_params:
                torch.nn.utils.clip_grad_norm_(param, max_norm)
        optimizer.step()
        stiefel_optimizer.step()
        lr_scheduler.step()
        stiefel_lr_scheduler.step()
        if (epoch + 1) % args.log_freq == 0:
            logging.info(" ".join([
                'Epoch: {:04d}'.format(epoch + 1),
                'lr: {:04f}, stie_lr: {:04f}'.format(
                    lr_scheduler.get_lr()[0],
                    stiefel_lr_scheduler.get_lr()[0]),
                format_metrics(train_metrics, 'train'),
                'time: {:.4f}s'.format(time.time() - t)
            ]))
        if (epoch + 1) % args.eval_freq == 0:
            model.eval()
            embeddings = model.encode(data['features'], data['hgnn_adj'],
                                      data['hgnn_weight'])
            for i in range(embeddings.size(0)):
                if (embeddings[i] != embeddings[i]).sum() > 1:
                    print('PART train  i', i, 'embeddings[i]', embeddings[i])
            val_metrics = model.compute_metrics(embeddings, data, 'val')
            if (epoch + 1) % args.log_freq == 0:
                logging.info(" ".join([
                    'Epoch: {:04d}'.format(epoch + 1),
                    format_metrics(val_metrics, 'val')
                ]))
            if model.has_improved(best_val_metrics, val_metrics):
                best_test_metrics = model.compute_metrics(
                    embeddings, data, 'test')
                best_emb = embeddings.cpu()
                if args.save:
                    np.save(os.path.join(save_dir, 'embeddings.npy'),
                            best_emb.detach().numpy())
                best_val_metrics = val_metrics
                counter = 0
            else:
                counter += 1
                if counter == args.patience and epoch > args.min_epochs:
                    logging.info("Early stopping")
                    break

    logging.info("Optimization Finished!")
    logging.info("Total time elapsed: {:.4f}s".format(time.time() - t_total))

    if args.save:
        np.save(os.path.join(save_dir,
                             str(args.dataset) + '_embeddings.npy'),
                best_emb.cpu().detach().numpy())
        torch.save(
            model.state_dict(),
            os.path.join(
                save_dir,
                str(args.dataset) + 'model_auc' + str(best_rocauc) + '.pth'))
        logging.info(f"Saved model in {save_dir}")
    if args.task == 'lp':
        return best_test_metrics['roc']
    if args.task == 'nc':
        return best_test_metrics['f1']
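
The per-row scan in Example #7 uses the fact that NaN != NaN to find embedding rows containing NaNs (the "> 1" threshold means it only reports rows with at least two NaN entries). A vectorized sketch with torch.isnan that flags any NaN row:

import torch

embeddings = torch.tensor([[0.0, 1.0], [float('nan'), 2.0]])
bad_rows = torch.isnan(embeddings).any(dim=1)
for i in torch.nonzero(bad_rows).flatten().tolist():
    print('PART train  i', i, 'embeddings[i]', embeddings[i])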
Example #8
    def fit(self):
        args = self.args
        model = self.model
        optimizer = self.optimizer
        data = self.data
        lr_scheduler = self.lr_scheduler
        save_dir = self.args.save_dir

        for epoch in range(args.epochs):
            t = time.time()
            model.train()
            optimizer.zero_grad()
            embeddings = model.encode(data['features'], self.adj_train_enc)
            train_metrics = model.compute_metrics(embeddings, data, 'train', epoch)
            train_metrics['loss'].backward()
            if args.grad_clip is not None:
                max_norm = float(args.grad_clip)
                all_params = list(model.parameters())
                for param in all_params:
                    torch.nn.utils.clip_grad_norm_(param, max_norm)
            optimizer.step()
            lr_scheduler.step()
            with torch.no_grad():
                if (epoch + 1) % args.log_freq == 0:
                    logging.info(" ".join(['Epoch: {:04d}'.format(epoch + 1),
                                           'lr: {}'.format(lr_scheduler.get_lr()[0]),
                                           format_metrics(train_metrics, 'train'),
                                           'time: {:.4f}s'.format(time.time() - t)
                                           ]))
                if (epoch + 1) % args.eval_freq == 0:
                    model.eval()
                    embeddings = model.encode(data['features'], self.adj_train_enc)
                    if args.node_cluster != 1:
                        ## Link Prediction Task that use train/val/test
                        val_metrics = model.compute_metrics(embeddings, data, 'val')
                        if (epoch + 1) % args.log_freq == 0:
                            logging.info(" ".join(['Epoch: {:04d}'.format(epoch + 1), format_metrics(val_metrics, 'val')]))
                        if model.has_improved(self.best_val_metrics, val_metrics):
                            self.best_test_metrics = model.compute_metrics(embeddings, data, 'test')
                            self.best_emb = embeddings
                            if args.save:
                                np.save(os.path.join(save_dir, 'embeddings.npy'), self.best_emb.cpu().detach().numpy())
                            self.best_val_metrics = val_metrics
                            self.best_val_metrics['epoch'] = epoch + 1
                            self.counter = 0
                            # logging.info("improved")
                        else:
                            # logging.info("not improved :"+str(self.counter))
                            self.counter += 1
                            if self.counter >= args.patience and epoch > args.min_epochs:  # NOTE : fixed when improve only epoch0
                                logging.info("Early stopping")
                                break
                    else:
                        ## Node Clustering Task that use 100 % trainset.
                        if self.best_test_metrics.get('loss', 999) > train_metrics['loss']:
                            '''
                            # NOTE : when kmeans calculated, affect np.state and takes time, and not fair to monitor it.
                            # metrics_clustering = model.eval_cluster(embeddings, data, 'all')
                            # logging.info(" ".join(["Cluster results:", format_metrics(metrics_clustering, 'all')]))
                            '''
                            self.best_emb = embeddings
                            logging.info("Best loss found")
                            self.best_test_metrics['loss'] = train_metrics['loss']
                            self.best_test_metrics['epoch'] = epoch + 1
                            self.counter = 0
                        else:
                            self.counter += 1
                            if self.counter >= args.patience and epoch > args.min_epochs:  # NOTE : fixed when improve only epoch0
                                logging.info("Early stopping")
                                break


        logging.info("Optimization Finished!")
        logging.info("Total time elapsed: {:.4f}s".format(time.time() - self.t_total))
Example #9
File: train.py  Project: forsubmit/RAHGNN
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if int(args.double_precision):
        torch.set_default_dtype(torch.float64)
    if int(args.cuda) >= 0:
        torch.cuda.manual_seed(args.seed)
    args.device = 'cuda:' + str(args.cuda) if int(args.cuda) >= 0 else 'cpu'
    args.patience = args.epochs if not args.patience else int(args.patience)
    logging.getLogger().setLevel(logging.INFO)
    if args.save:
        if not args.save_dir:
            dt = datetime.datetime.now()
            date = f"{dt.year}_{dt.month}_{dt.day}"
            models_dir = os.path.join(os.environ['LOG_DIR'], args.task, date)
            save_dir = get_dir_name(models_dir)
        else:
            save_dir = args.save_dir
        logging.basicConfig(level=logging.INFO,
                            handlers=[
                                logging.FileHandler(
                                    os.path.join(save_dir, 'log.txt')),
                                logging.StreamHandler()
                            ])

    logging.info(f'Using: {args.device}')
    logging.info("Using seed {}.".format(args.seed))
    warnings.filterwarnings(action='ignore')

    # Load data
    data = load_data(args, os.path.join(os.environ['DATAPATH'], args.dataset))
    args.n_nodes, args.feat_dim = data['features'].shape
    if args.task == 'nc':
        Model = NCModel
        args.n_classes = int(data['labels'].max() + 1)
        logging.info(f'Num classes: {args.n_classes}')
    else:
        args.nb_false_edges = len(data['train_edges_false'])
        args.nb_edges = len(data['train_edges'])
        if args.task == 'lp':
            Model = LPModel

    if not args.lr_reduce_freq:
        args.lr_reduce_freq = args.epochs

    # Model and optimizer
    # Initialize RL environment
    lr_q = args.lr_q  # RL Learning Rate
    action_space1 = ['reject', 'accept']  # HGNN action space
    action_space2 = [  # ACE  action space
        'r, r', 'r, a', 'a, r', 'a, a'
    ]
    joint_actions = []
    for i in range(len(action_space1)):  # Joint action space
        for j in range(len(action_space2)):
            joint_actions.append((i, j))
    env = Env(theta=args.theta, initial_c=args.c)
    Agent1 = QLearningTable(actions=list(range(len(action_space1))),
                            joint=joint_actions,
                            start=args.start_q,
                            learning_rate=lr_q)
    Agent2 = QLearningTable(actions=list(range(len(action_space2))),
                            joint=joint_actions,
                            start=args.start_q,
                            learning_rate=lr_q)

    hgnn = Model(args)  # Agent1 HGNN
    ace = Model(args)  # Agent2 ACE
    logging.info(str(hgnn))
    optimizer = getattr(optimizers,
                        args.optimizer)(params=hgnn.parameters(),
                                        lr=args.lr,
                                        weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=int(
                                                       args.lr_reduce_freq),
                                                   gamma=float(args.gamma))
    tot_params = sum([np.prod(p.size()) for p in hgnn.parameters()])
    logging.info(f"Total number of parameters: {tot_params}")
    if args.cuda is not None and int(args.cuda) >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
        hgnn = hgnn.to(args.device)
        ace = ace.to(args.device)
        for x, val in data.items():
            if torch.is_tensor(data[x]):
                data[x] = data[x].to(args.device)

    # Train model
    t_total = time.time()
    counter = 0
    best_val_metrics = hgnn.init_metric_dict()
    best_test_metrics = None
    best_emb = None
    val_metric_record = []
    train_metric_record = []
    for epoch in range(args.epochs):
        t = time.time()
        hgnn.train()
        ace.train()
        optimizer.zero_grad()

        # train model with RL and return Agent1's train metrics
        # Terminate mechanism
        if epoch > args.start_q + 30:
            r1 = np.array(env.c1_record)[-30:-1, 0]
            r2 = np.array(env.c1_record)[-30:-1, 1]
            if abs(max(r1) - min(r1)) <= 0.03 and not env.stop[0]:
                env.stop[0] = True
                print("Layer1 RL terminate at {:.3f}.".format(
                    env.c1_record[-1][0]))
                counter = args.patience // 2
            if abs(max(r2) - min(r2)) <= 0.03 and not env.stop[1]:
                env.stop[1] = True
                print("Layer2 RL terminate at {:.3f}.".format(
                    env.c1_record[-1][1]))
                counter = args.patience // 2

        train_metrics = hgnn.train_with_RL(env, Agent1, Agent2, data, epoch,
                                           ace)
        train_metric_record.append(train_metrics[hgnn.key_param])
        train_metrics['loss'].backward()
        if args.grad_clip is not None:
            max_norm = float(args.grad_clip)
            all_params = list(hgnn.parameters())
            for param in all_params:
                torch.nn.utils.clip_grad_norm_(param, max_norm)
        optimizer.step()
        lr_scheduler.step()

        if (epoch + 1) % args.log_freq == 0:
            logging.info(" ".join([
                'Epoch: {:04d}'.format(epoch + 1),
                'lr: {}'.format(lr_scheduler.get_lr()[0]),
                format_metrics(train_metrics, 'train'),
                'time: {:.4f}s'.format(time.time() - t)
            ]))
        if (epoch + 1) % args.eval_freq == 0:
            hgnn.eval()
            embeddings = hgnn.encode(data['features'], data['adj_train_norm'])
            val_metrics = hgnn.compute_metrics(embeddings, data, 'val')
            val_metric_record.append(val_metrics[hgnn.key_param])

            if (epoch + 1) % args.log_freq == 0:
                logging.info(" ".join([
                    'Epoch: {:04d}'.format(epoch + 1),
                    format_metrics(val_metrics, 'val')
                ]))

            if hgnn.has_improved(best_val_metrics, val_metrics):
                best_test_metrics = hgnn.compute_metrics(
                    embeddings, data, 'test')
                best_emb = embeddings.cpu()
                if args.save:
                    np.save(os.path.join(save_dir, 'embeddings.npy'),
                            best_emb.detach().numpy())
                best_val_metrics = val_metrics
                counter = 0
            else:
                counter += 1
                if counter >= args.patience and epoch > args.min_epochs and all(
                        env.stop):
                    logging.info("Early stopping")
                    break

    logging.info("Optimization Finished!")
    logging.info("Total time elapsed: {:.4f}s".format(time.time() - t_total))
    if not best_test_metrics:
        hgnn.eval()
        best_emb = hgnn.encode(data['features'], data['adj_train_norm'])
        best_test_metrics = hgnn.compute_metrics(best_emb, data, 'test')
    logging.info(" ".join(
        ["Val set results:",
         format_metrics(best_val_metrics, 'val')]))
    logging.info(" ".join(
        ["Test set results:",
         format_metrics(best_test_metrics, 'test')]))
    if args.save:
        # Save embeddings and attentions
        np.save(os.path.join(save_dir, 'embeddings.npy'),
                best_emb.cpu().detach().numpy())
        if hasattr(hgnn.encoder, 'att_adj'):
            filename = os.path.join(save_dir, args.dataset + '_att_adj.p')
            pickle.dump(hgnn.encoder.att_adj.cpu().to_dense(),
                        open(filename, 'wb'))
            print('Dumped attention adj: ' + filename)

        # Save model
        json.dump(vars(args), open(os.path.join(save_dir, 'config.json'), 'w'))
        torch.save(hgnn.state_dict(), os.path.join(save_dir, 'model.pth'))

        # Save curvature record and figures
        np.save(os.path.join(save_dir, 'curv1.npy'), np.array(env.c1_record))
        np.save(os.path.join(save_dir, 'curv2.npy'), np.array(env.c2_record))

        # Save acc record
        np.save(os.path.join(save_dir, 'metric_record.npy'),
                np.array([train_metric_record, val_metric_record]))

        logging.info("Agent1: {}, Agent2: {}".format(env.c1, env.c2))
        logging.info(f"Saved model in {save_dir}")
Example #10
def train(args):
    if args.dataset in ['disease_lp', 'disease_nc']:
        args.normalize_feats = 0
    if args.task == 'nc':
        args.num_layers += 1
    if args.manifold in ('Lorentzian', 'Euclidean'):
        args.dim = args.dim + 1
    args.c = float(args.c) if args.c is not None else None
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if int(args.double_precision):
        torch.set_default_dtype(torch.float64)
    if int(args.cuda) >= 0:
        torch.cuda.manual_seed(args.seed)
    args.device = 'cuda:' + str(args.cuda) if int(args.cuda) >= 0 else 'cpu'
    args.patience = args.epochs if not args.patience else int(args.patience)
    logging.getLogger().setLevel(logging.INFO)
    if args.save:
        if not args.save_dir:
            dt = datetime.datetime.now()
            date = f"{dt.year}_{dt.month}_{dt.day}"
            models_dir = os.path.join('save', args.task, date)
            save_dir = get_dir_name(models_dir)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
        else:
            save_dir = args.save_dir
        logging.basicConfig(level=logging.INFO,
                            handlers=[
                                logging.FileHandler(os.path.join(save_dir, 'log.txt')),
                                logging.StreamHandler()
                            ])

    logging.info(f'Using cuda: {args.cuda}')
    logging.info("Using seed {}.".format(args.seed))
    logging.info("Using dataset {}.".format(args.dataset))

    # Load data
    data = load_data(args, os.path.join('data/', args.dataset))
    args.n_nodes, args.feat_dim = data['features'].shape
    if args.task == 'nc':
        Model = NCModel
        args.n_classes = int(data['labels'].max() + 1)
        logging.info(f'Num classes: {args.n_classes}')
    else:
        args.nb_false_edges = len(data['train_edges_false'])
        args.nb_edges = len(data['train_edges'])
        if args.task == 'lp':
            Model = LPModel
        else:
            raise NotImplementedError
            Model = RECModel
            # No validation for reconstruction task
            args.eval_freq = args.epochs + 1

    if not args.lr_reduce_freq:
        args.lr_reduce_freq = args.epochs

    # Model and optimizer
    model = Model(args)
    if args.cuda is not None and int(args.cuda) >= 0:
        model = model.to(args.device)
        for x, val in data.items():
            if torch.is_tensor(data[x]):
                data[x] = data[x].to(args.device)
    for iter_i in [str(args.run_times)]:
        optimizer = getattr(optimizers, args.optimizer)(params=model.parameters(), lr=args.lr,
                                                    weight_decay=args.weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=int(args.lr_reduce_freq),
            gamma=float(args.gamma)
        )
        # Train model
        t_total = time.time()
        counter = 0
        best_val_metrics = model.init_metric_dict()
        best_test_metrics = None
        best_emb = None
        for epoch in range(args.epochs):
            t = time.time()
            model.train()
            optimizer.zero_grad()
            embeddings = model.encode(data['features'], data['adj_train_norm'])
            train_metrics = model.compute_metrics(embeddings, data, 'train')
            train_metrics['loss'].backward()
            if args.grad_clip is not None:
                max_norm = float(args.grad_clip)
                all_params = list(model.parameters())
                for param in all_params:
                    torch.nn.utils.clip_grad_norm_(param, max_norm)
            optimizer.step()
            lr_scheduler.step()
            if (epoch + 1) % args.log_freq == 0:
                logging.info(" ".join(['{} {} Epoch: {} {:04d}'.format(args.dataset, args.task, iter_i, epoch + 1),
                                       'lr: {}'.format(lr_scheduler.get_lr()[0]),
                                       format_metrics(train_metrics, 'train'),
                                       'time: {:.4f}s'.format(time.time() - t)
                                       ]))
            if (epoch + 1) % args.eval_freq == 0:
                model.eval()
                embeddings = model.encode(data['features'], data['adj_train_norm'])
                val_metrics = model.compute_metrics(embeddings, data, 'val')
                if (epoch + 1) % args.log_freq == 0:
                    logging.info(" ".join(['{} {} Epoch: {} {:04d}'.format(args.dataset, args.task, iter_i, epoch + 1), format_metrics(val_metrics, 'val')]))
                if model.has_improved(best_val_metrics, val_metrics):
                    best_test_metrics = model.compute_metrics(embeddings, data, 'test')
                    best_emb = embeddings.cpu()
                    if args.save:
                        np.save(os.path.join(save_dir, 'embeddings.npy'), best_emb.detach().numpy())
                    best_val_metrics = val_metrics
                    counter = 0
                else:
                    counter += 1
                    if counter == args.patience and epoch > args.min_epochs:
                        logging.info("Early stopping")
                        break

        logging.info("Optimization Finished!")
        logging.info("Total time elapsed: {:.4f}s".format(time.time() - t_total))
        if not best_test_metrics:
            model.eval()
            best_emb = model.encode(data['features'], data['adj_train_norm'])
            best_test_metrics = model.compute_metrics(best_emb, data, 'test')
        logging.info(" ".join(["Val set results:", format_metrics(best_val_metrics, 'val')]))
        logging.info(" ".join(["Test set results:", format_metrics(best_test_metrics, 'test')]))

        if args.save:
            print('Saved path', os.path.join(save_dir, 'embeddings.npy'))
            np.save(os.path.join(save_dir, 'embeddings.npy'), best_emb.cpu().detach().numpy())
            if hasattr(model.encoder, 'att_adj'):
                filename = os.path.join(save_dir, args.dataset + '_att_adj.p')
                pickle.dump(model.encoder.att_adj.cpu().to_dense(), open(filename, 'wb'))
                print('Dumped attention adj: ' + filename)
            json.dump(vars(args), open(os.path.join(save_dir, 'config.json'), 'w'))
            torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth'))
            logging.info(f"Saved model in {save_dir}")