Exemplo n.º 1
0
def train(args, exe, train_prog, agent, train_data_list, epoch_id):
    """Run one training epoch and return the mean of the per-batch losses.

    Batches come from a shuffled ``Dataloader`` that collates samples with a
    classification ``MoleculeCollateFunc`` (graph labels on, pos/neg mask
    off). The running loss is logged every ``args.log_interval`` batches and
    the epoch average is logged once after the last batch.

    Args:
        args: Run configuration; reads ``num_tasks``, ``batch_size``,
            ``num_workers``, ``log_interval`` and ``exp``.
        exe: Object whose ``run`` method executes ``train_prog``.
        train_prog: Program handed to ``exe.run``.
        agent: Supplies ``graph_wrapper`` and the fetched ``loss``.
        train_data_list: Training samples; must support ``len()``.
        epoch_id: Epoch number, used only in the log messages.

    Returns:
        ``np.mean`` over the collected per-batch loss values.
    """
    batch_builder = MoleculeCollateFunc(
        agent.graph_wrapper,
        task_type='cls',
        num_cls_tasks=args.num_tasks,
        with_graph_label=True,
        with_pos_neg_mask=False)
    loader = Dataloader(
        train_data_list,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        collate_fn=batch_builder)

    num_total = len(train_data_list)
    num_seen = 0
    batch_losses = []
    for step, batch in enumerate(loader):
        fetched = exe.run(train_prog, feed=batch, fetch_list=[agent.loss])
        step_loss = np.array(fetched).mean()
        batch_losses.append(step_loss)
        # First entry of 'graph/num_graph' is used as this batch's sample
        # count (presumably the number of graphs -- confirm with collate fn).
        num_seen += batch['graph/num_graph'][0]

        if batch_id_is_silent := (step % args.log_interval != 0):
            continue
        progress = '%s Epoch %d [%d/%d] train/loss:%f' % \
            (args.exp, epoch_id, num_seen, num_total, step_loss)
        logging.info(progress)

    epoch_loss = np.mean(batch_losses)
    summary = '%s Epoch %d train/loss:%f' % \
        (args.exp, epoch_id, epoch_loss)
    logging.info(summary)
    sys.stdout.flush()
    return epoch_loss
Exemplo n.º 2
0
def evaluate(args, exe, test_prog, agent, test_data_list, epoch_id):
    """Evaluate the model on test dataset.

    Runs ``test_prog`` over every batch of ``test_data_list`` (unshuffled),
    gathers the fetched predictions together with the ``'label'`` and
    ``'valid'`` entries of each feed dict, concatenates them along axis 0
    and scores them with ``calc_rocauc_score``.

    Args:
        args: Run configuration; reads ``num_tasks``, ``batch_size`` and
            ``num_workers``.
        exe: Object whose ``run`` method executes ``test_prog``.
        test_prog: Program handed to ``exe.run``.
        agent: Supplies ``graph_wrapper`` and the fetched ``pred``.
        test_data_list: Evaluation samples.
        epoch_id: Not used by this function; kept so the signature stays
            compatible with existing callers.

    Returns:
        Whatever ``calc_rocauc_score(labels, preds, valid)`` returns.

    Raises:
        ValueError: If the data loader yields no batches.
    """
    collate_fn = MoleculeCollateFunc(agent.graph_wrapper,
                                     task_type='cls',
                                     num_cls_tasks=args.num_tasks,
                                     with_graph_label=True,
                                     with_pos_neg_mask=False)
    data_loader = Dataloader(test_data_list,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers,
                             shuffle=False,
                             collate_fn=collate_fn)

    total_pred, total_label, total_valid = [], [], []
    for feed_dict in data_loader:
        # return_numpy=False hands back the raw fetched object, so it is
        # converted explicitly before being accumulated.
        pred, = exe.run(test_prog,
                        feed=feed_dict,
                        fetch_list=[agent.pred],
                        return_numpy=False)
        total_pred.append(np.array(pred))
        total_label.append(feed_dict['label'])
        total_valid.append(feed_dict['valid'])

    if not total_pred:
        # np.concatenate([]) would raise the same exception type, but with
        # a message that does not point at the real cause.
        raise ValueError(
            'evaluate: the test data loader produced no batches')

    total_pred = np.concatenate(total_pred, 0)
    total_label = np.concatenate(total_label, 0)
    total_valid = np.concatenate(total_valid, 0)
    return calc_rocauc_score(total_label, total_pred, total_valid)
def train(args, exe, train_prog, agent, train_data_list, epoch_id):
    """Run one training epoch without graph labels and log the losses.

    Samples are collated by a classification ``MoleculeCollateFunc`` with
    graph labels off and the pos/neg mask on, then served shuffled by a
    ``Dataloader``. The running loss is logged every ``args.log_interval``
    batches. The epoch average is logged (and stdout flushed) only when
    ``args.is_fleet`` is falsy or ``fleet.worker_index()`` is 0.

    Args:
        args: Run configuration; reads ``batch_size``, ``num_workers``,
            ``log_interval`` and ``is_fleet``.
        exe: Object whose ``run`` method executes ``train_prog``.
        train_prog: Program handed to ``exe.run``.
        agent: Supplies ``graph_wrapper`` and the fetched ``loss``.
        train_data_list: Training samples; must support ``len()``.
        epoch_id: Epoch number, used only in the log messages.

    Returns:
        None.
    """
    batch_builder = MoleculeCollateFunc(
        agent.graph_wrapper,
        task_type='cls',
        with_graph_label=False,  # for unsupervised learning
        with_pos_neg_mask=True)
    loader = Dataloader(
        train_data_list,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        collate_fn=batch_builder)

    num_total = len(train_data_list)
    num_seen = 0
    batch_losses = []
    for step, batch in enumerate(loader):
        fetched = exe.run(train_prog, feed=batch, fetch_list=[agent.loss])
        step_loss = np.array(fetched).mean()
        batch_losses.append(step_loss)
        # First entry of 'graph/num_graph' is used as this batch's sample
        # count (presumably the number of graphs -- confirm with collate fn).
        num_seen += batch['graph/num_graph'][0]

        if step % args.log_interval != 0:
            continue
        progress = 'Epoch %d [%d/%d] train/loss:%f' % \
            (epoch_id, num_seen, num_total, step_loss)
        logging.info(progress)

    # Under fleet, every worker except index 0 skips the epoch summary.
    if args.is_fleet and fleet.worker_index() != 0:
        return
    summary = 'Epoch %d train/loss:%f' % (epoch_id, np.mean(batch_losses))
    logging.info(summary)
    sys.stdout.flush()
def save_embedding(args, exe, test_prog, agent, data_list, epoch_id):
    """Compute embeddings for ``data_list`` and pickle them with labels.

    The data is served unshuffled by a single-worker ``Dataloader``, and
    ``agent.encoder.get_embeddings`` produces the embeddings and labels.
    Both are cut to the first ``len(data_list)`` entries and written as a
    dict with keys ``'emb'`` and ``'y'`` to
    ``<args.emb_dir>/epoch_<epoch_id>.pkl``.

    Args:
        args: Run configuration; reads ``batch_size`` and ``emb_dir``.
        exe: Passed through to ``get_embeddings``.
        test_prog: Passed through to ``get_embeddings``.
        agent: Supplies ``graph_wrapper``, ``encoder`` and ``graph_emb``.
        data_list: Samples to embed; must support ``len()``.
        epoch_id: Used to name the output file.

    Returns:
        None.
    """
    batch_builder = MoleculeCollateFunc(
        agent.graph_wrapper,
        task_type='cls',
        with_graph_label=True,  # save emb & label for supervised learning
        with_pos_neg_mask=True)
    loader = Dataloader(
        data_list,
        batch_size=args.batch_size,
        num_workers=1,
        shuffle=False,
        collate_fn=batch_builder)

    embeddings, labels = agent.encoder.get_embeddings(
        loader, exe, test_prog, agent.graph_emb)

    # Keep only as many rows as there are input samples (presumably the
    # encoder can return extras -- confirm against get_embeddings).
    num_samples = len(data_list)
    payload = {
        'emb': embeddings[:num_samples],
        'y': labels[:num_samples],
    }

    out_path = '%s/epoch_%s.pkl' % (args.emb_dir, epoch_id)
    with open(out_path, 'wb') as out_file:
        pickle.dump(payload, out_file)