Example #1
    def test_IterDataset(self):
        config = {
            'batch_size': 3,
            'drop_last': True,
            'num_workers': 2,
        }
        collate_fn = Collate_fn(config)
        ds = IterDataset()
        loader = Dataloader(ds,
                            batch_size=config['batch_size'],
                            drop_last=config['drop_last'],
                            num_workers=config['num_workers'],
                            collate_fn=collate_fn)

        epochs = 1
        for e in range(epochs):
            res = []
            for batch_data in loader:
                res.extend(batch_data['data'])
                self.assertEqual(len(batch_data['data']), config['batch_size'])

        # test drop_last=False: every item should be yielded exactly once
        loader = Dataloader(ds,
                            batch_size=3,
                            drop_last=False,
                            num_workers=1,
                            collate_fn=collate_fn)

        for e in range(epochs):
            res = []
            for batch_data in loader:
                res.extend(batch_data['data'])
            self.assertEqual(set(range(DATA_SIZE)), set(res))
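This test relies on a Collate_fn helper defined elsewhere in the test module. A minimal sketch of what it presumably looks like, assuming the dataset yields plain integers (the real helper is not shown here and may do more):

# Minimal sketch of the Collate_fn helper these tests assume; the actual
# helper lives elsewhere in the test module.
class Collate_fn(object):
    def __init__(self, config):
        self.config = config

    def __call__(self, batch_data_list):
        # batch_data_list is the sub-list drawn from the dataset; aggregate
        # it under the 'data' key that the assertions read from.
        return {'data': list(batch_data_list)}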
Example #2
def train(args, exe, train_prog, agent, train_data_list, epoch_id):
    """Model training for one epoch and log the average loss."""
    collate_fn = MoleculeCollateFunc(agent.graph_wrapper,
                                     task_type='cls',
                                     num_cls_tasks=args.num_tasks,
                                     with_graph_label=True,
                                     with_pos_neg_mask=False)
    data_loader = Dataloader(train_data_list,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers,
                             shuffle=True,
                             collate_fn=collate_fn)

    total_data, trained_data = len(train_data_list), 0
    list_loss = []
    for batch_id, feed_dict in enumerate(data_loader):
        train_loss = exe.run(train_prog,
                             feed=feed_dict,
                             fetch_list=[agent.loss])
        train_loss = np.array(train_loss).mean()
        list_loss.append(train_loss)
        trained_data += feed_dict['graph/num_graph'][0]

        if batch_id % args.log_interval == 0:
            logging.info(
                '%s Epoch %d [%d/%d] train/loss:%f' % \
                (args.exp, epoch_id, trained_data, total_data, train_loss))

    logging.info('%s Epoch %d train/loss:%f' % \
                 (args.exp, epoch_id, np.mean(list_loss)))
    sys.stdout.flush()
    return np.mean(list_loss)
Example #3
def evaluate(args, exe, test_prog, agent, test_data_list, epoch_id):
    """Evaluate the model on test dataset."""
    collate_fn = MoleculeCollateFunc(agent.graph_wrapper,
                                     task_type='cls',
                                     num_cls_tasks=args.num_tasks,
                                     with_graph_label=True,
                                     with_pos_neg_mask=False)
    data_loader = Dataloader(test_data_list,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers,
                             shuffle=False,
                             collate_fn=collate_fn)

    total_data, eval_data = len(test_data_list), 0
    total_pred, total_label, total_valid = [], [], []
    for batch_id, feed_dict in enumerate(data_loader):
        pred, = exe.run(test_prog,
                        feed=feed_dict,
                        fetch_list=[agent.pred],
                        return_numpy=False)
        total_pred.append(np.array(pred))
        total_label.append(feed_dict['label'])
        total_valid.append(feed_dict['valid'])

    total_pred = np.concatenate(total_pred, 0)
    total_label = np.concatenate(total_label, 0)
    total_valid = np.concatenate(total_valid, 0)
    return calc_rocauc_score(total_label, total_pred, total_valid)
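Note that feed_dict['valid'] is accumulated alongside the labels and predictions; it presumably acts as a per-task validity mask so that calc_rocauc_score can skip missing labels when averaging ROC-AUC over the num_cls_tasks classification tasks.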
Example #4
def train(args, exe, train_prog, agent, train_data_list, epoch_id):
    """Unsupervised model training for one epoch; logs the average loss."""
    collate_fn = MoleculeCollateFunc(
        agent.graph_wrapper,
        task_type='cls',
        with_graph_label=False,  # for unsupervised learning
        with_pos_neg_mask=True)
    data_loader = Dataloader(train_data_list,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers,
                             shuffle=True,
                             collate_fn=collate_fn)

    total_data, trained_data = len(train_data_list), 0
    list_loss = []
    for batch_id, feed_dict in enumerate(data_loader):
        train_loss = exe.run(train_prog,
                             feed=feed_dict,
                             fetch_list=[agent.loss])
        train_loss = np.array(train_loss).mean()
        list_loss.append(train_loss)
        trained_data += feed_dict['graph/num_graph'][0]

        if batch_id % args.log_interval == 0:
            logging.info('Epoch %d [%d/%d] train/loss:%f' % \
                         (epoch_id, trained_data, total_data, train_loss))

    if not args.is_fleet or fleet.worker_index() == 0:
        logging.info('Epoch %d train/loss:%f' % (epoch_id, np.mean(list_loss)))
        sys.stdout.flush()
Example #5
def infer(args):
    log.info("loading data")
    raw_dataset = GraphPropPredDataset(name=args.dataset_name)
    args.num_class = raw_dataset.num_tasks
    args.eval_metric = raw_dataset.eval_metric
    args.task_type = raw_dataset.task_type

    test_ds = MolDataset(args, raw_dataset, mode="test")

    fn = MgfCollateFn(args, mode="test")

    test_loader = Dataloader(test_ds,
                             batch_size=args.batch_size,
                             num_workers=1,
                             collate_fn=fn)
    test_loader = PDataset.from_generator_func(test_loader)

    est = propeller.Learner(MgfModel, args, args.model_config)

    mgf_list = []
    for soft_mgf in est.predict(test_loader,
                                ckpt_path=args.model_path_for_infer,
                                split_batch=True):
        mgf_list.append(soft_mgf)

    mgf = np.concatenate(mgf_list)
    log.info("saving features")
    np.save(
        "dataset/%s/soft_mgf_feat.npy" % (args.dataset_name.replace("-", "_")),
        mgf)
Example #6
    def iter_batch(self,
                   batch_size,
                   num_workers=4,
                   shuffle=False,
                   collate_fn=None):
        """
        It returns an batch iterator which yields a batch of data. Firstly, a sub-list of
        `data` of size ``batch_size`` will be draw from the ``data_list``, then 
        the function ``collate_fn`` will be applied to the sub-list to create a batch and 
        yield back. This process is accelerated by multiprocess.

        Args:
            batch_size(int): the batch_size of the batch data of each yield.
            num_workers(int): the number of workers used to generate batch data. Required by 
                multiprocess.
            shuffle(bool): whether to shuffle the order of the ``data_list``.
            collate_fn(function): used to convert the sub-list of ``data_list`` to the 
                aggregated batch data.

        Yields:
            the batch data processed by ``collate_fn``.
        """
        return Dataloader(self,
                          batch_size=batch_size,
                          num_workers=num_workers,
                          shuffle=shuffle,
                          collate_fn=collate_fn)
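A hedged usage sketch for iter_batch; dataset, my_collate, and the dict layout are illustrative assumptions, not part of the library:

# Hypothetical usage of iter_batch; `dataset` is assumed to be an instance
# of the surrounding class, and my_collate is an illustrative stand-in.
def my_collate(sub_list):
    return {'data': sub_list}

for batch in dataset.iter_batch(batch_size=32,
                                num_workers=4,
                                shuffle=True,
                                collate_fn=my_collate):
    print(len(batch['data']))  # == 32 for every full batch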
Example #7
    def iter_batch(self, batch_size, num_workers=4, shuffle_size=1000, collate_fn=None):
        """Return a batch iterator over a streaming dataset, shuffling with a
        buffer of ``shuffle_size`` items."""
        class _TempDataset(PglStreamDataset):
            def __init__(self, data_generator):
                self.data_generator = data_generator
            def __iter__(self):
                for data in self.data_generator:
                    yield data

        return Dataloader(_TempDataset(self.data_generator),
                          batch_size=batch_size,
                          num_workers=num_workers,
                          stream_shuffle_size=shuffle_size,
                          collate_fn=collate_fn)
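Unlike the list-backed variant above, this streaming variant cannot shuffle the whole dataset up front, so it exposes shuffle_size instead of shuffle: the loader presumably keeps a buffer of about shuffle_size items and samples batches from it, trading shuffle quality for bounded memory.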
Example #8
def save_embedding(args, exe, test_prog, agent, data_list, epoch_id):
    """Compute graph embeddings and labels for ``data_list`` and pickle them."""
    collate_fn = MoleculeCollateFunc(
        agent.graph_wrapper,
        task_type='cls',
        with_graph_label=True,  # save emb & label for supervised learning
        with_pos_neg_mask=True)
    data_loader = Dataloader(data_list,
                             batch_size=args.batch_size,
                             num_workers=1,
                             shuffle=False,
                             collate_fn=collate_fn)

    emb, y = agent.encoder.get_embeddings(data_loader, exe, test_prog,
                                          agent.graph_emb)
    emb, y = emb[:len(data_list)], y[:len(data_list)]
    merge_data = {'emb': emb, 'y': y}
    with open('%s/epoch_%s.pkl' % (args.emb_dir, epoch_id), 'wb') as f:
        pickle.dump(merge_data, f)
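The slice emb[:len(data_list)] suggests the collate function pads the final batch up to a full batch_size; truncating to the dataset length drops those padding entries before the embeddings are pickled.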
Example #9
def train(args, exe, train_program, model, train_dataset):
    label_name = 'KIBA' if args.use_kiba_label else 'Log10_Kd'
    collate_fn = DTACollateFunc(model.compound_graph_wrapper,
                                is_inference=False,
                                label_name=label_name)
    data_loader = Dataloader(train_dataset,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers,
                             stream_shuffle_size=1000,
                             collate_fn=collate_fn)

    list_loss = []
    for feed_dict in data_loader:
        train_loss, = exe.run(train_program,
                              feed=feed_dict,
                              fetch_list=[model.loss],
                              return_numpy=False)
        list_loss.append(np.array(train_loss).mean())
    return np.mean(list_loss)
Example #10
    def test_ListDataset_Order(self):
        config = {
            'batch_size': 2,
            'drop_last': False,
            'shuffle': False,
            'num_workers': 4,
        }
        collate_fn = Collate_fn(config)
        ds = ListDataset()

        # test batch_size
        loader = Dataloader(ds,
                            batch_size=config['batch_size'],
                            drop_last=config['drop_last'],
                            num_workers=config['num_workers'],
                            collate_fn=collate_fn)

        epochs = 1
        for e in range(epochs):
            res = []
            for batch_data in loader:
                res.extend(batch_data['data'])
            self.assertEqual(list(range(DATA_SIZE)), res)
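Note what this asserts: with shuffle disabled the loader must reproduce the dataset order exactly even with num_workers=4, so batches produced by the worker processes are evidently collected back in submission order.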
Example #11
def evaluate(args,
             exe,
             test_program,
             model,
             test_dataset,
             best_mse,
             val_dataset=None):
    """tbd"""
    if args.use_val:
        assert val_dataset is not None

    label_name = 'KIBA' if args.use_kiba_label else 'Log10_Kd'
    collate_fn = DTACollateFunc(model.compound_graph_wrapper,
                                is_inference=False,
                                label_name=label_name)
    eval_dataset = val_dataset if args.use_val else test_dataset
    data_loader = Dataloader(eval_dataset,
                             batch_size=args.batch_size,
                             num_workers=1,
                             collate_fn=collate_fn)

    if args.use_val:
        test_dataloader = Dataloader(test_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=1,
                                     collate_fn=collate_fn)

    total_n, processed = len(eval_dataset), 0
    total_pred, total_label = [], []
    for idx, feed_dict in enumerate(data_loader):
        logging.info('Evaluated {}/{}'.format(processed, total_n))
        pred, = exe.run(test_program,
                        feed=feed_dict,
                        fetch_list=[model.pred],
                        return_numpy=False)
        total_pred.append(np.array(pred))
        total_label.append(feed_dict['label'])
        processed += total_pred[-1].shape[0]

    logging.info('Evaluated {}/{}'.format(processed, total_n))
    total_pred = np.concatenate(total_pred, 0).flatten()
    total_label = np.concatenate(total_label, 0).flatten()
    mse = ((total_label - total_pred)**2).mean(axis=0)

    test_mse, test_ci, ci = None, None, None
    if mse < best_mse and not args.use_val:
        # Computing CI is time consuming
        ci = concordance_index(total_label, total_pred)
    elif mse < best_mse and args.use_val:
        total_pred, total_label = [], []
        for idx, feed_dict in enumerate(test_dataloader):
            pred, = exe.run(test_program,
                            feed=feed_dict,
                            fetch_list=[model.pred],
                            return_numpy=False)
            total_pred.append(np.array(pred))
            total_label.append(feed_dict['label'])

        total_pred = np.concatenate(total_pred, 0).flatten()
        total_label = np.concatenate(total_label, 0).flatten()
        test_mse = ((total_label - total_pred)**2).mean(axis=0)
        test_ci = concordance_index(total_label, total_pred)

    if args.use_val:
        # here `mse` is the validation MSE; when it does not improve on
        # `best_mse`, test_mse and test_ci remain None
        return mse, test_mse, test_ci
    else:
        return mse, ci
Example #12
if args.load_epoch > 0:
    # restore the checkpoint into the model (paddle.load alone only returns
    # the state dict; it must be applied with set_state_dict)
    model.set_state_dict(
        paddle.load(args.save_dir + "/model.iter-" + str(args.load_epoch)))

scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=config['lr'],
                                                 gamma=config['anneal_rate'],
                                                 verbose=True)
clip = paddle.nn.ClipGradByNorm(clip_norm=config['clip_norm'])
optimizer = paddle.optimizer.Adam(parameters=model.parameters(),
                                  learning_rate=scheduler,
                                  grad_clip=clip)

train_dataset = JtnnDataSet(args.train)
collate_fn = JtnnCollateFn(vocab, True)
data_loader = Dataloader(train_dataset,
                         batch_size=args.batch_size,
                         num_workers=args.num_workers,
                         stream_shuffle_size=100,
                         collate_fn=collate_fn)

total_step = args.load_epoch
beta = config['beta']
meters = np.zeros(4)
for epoch in range(args.epoch):
    for batch in data_loader:
        total_step += 1
        res = model(batch, beta)
        loss = res['loss']
        kl_div = res['kl_div']
        wacc = res['word_acc']
        tacc = res['topo_acc']
        sacc = res['assm_acc']
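The snippet is truncated before the parameter update. With the paddle 2.x dygraph API set up above, the usual completion of the inner loop would look like this sketch (an assumption, not the author's code):

        # Assumed continuation of the truncated training loop:
        loss.backward()          # backpropagate the total loss
        optimizer.step()         # apply gradients (grad_clip is applied here)
        optimizer.clear_grad()   # reset gradients for the next batch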
Example #13
def train(args, pretrained_model_config=None):
    log.info("loading data")
    raw_dataset = GraphPropPredDataset(name=args.dataset_name)
    args.num_class = raw_dataset.num_tasks
    args.eval_metric = raw_dataset.eval_metric
    args.task_type = raw_dataset.task_type

    train_ds = MolDataset(args, raw_dataset)

    args.eval_steps = math.ceil(len(train_ds) / args.batch_size)
    log.info("Total %s steps (eval_steps) every epoch." % (args.eval_steps))

    fn = MgfCollateFn(args)

    train_loader = Dataloader(train_ds,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers,
                              shuffle=args.shuffle,
                              stream_shuffle_size=args.shuffle_size,
                              collate_fn=fn)

    # for evaluating
    eval_train_loader = PDataset.from_generator_func(train_loader)

    train_loader = multi_epoch_dataloader(train_loader, args.epochs)
    train_loader = PDataset.from_generator_func(train_loader)

    if args.warm_start_from is not None:
        # warm start setting
        def _fn(v):
            # warm start only parameters whose files exist in the checkpoint dir
            if not isinstance(v, F.framework.Parameter):
                return False
            return os.path.exists(os.path.join(args.warm_start_from, v.name))

        ws = propeller.WarmStartSetting(predicate_fn=_fn,
                                        from_dir=args.warm_start_from)
    else:
        ws = None

    def cmp_fn(old, new):
        # a lower metric value counts as an improvement (loss-like metrics)
        if new['eval'][args.metrics] < old['eval'][args.metrics]:
            log.info("best %s eval result: %s" % (args.metrics, new['eval']))
            return True
        else:
            return False

    if args.log_id is not None:
        save_best_model = int(args.log_id) == 5
    else:
        save_best_model = True
    best_exporter = propeller.exporter.BestResultExporter(
        args.output_dir, (cmp_fn, save_best_model))

    eval_datasets = {"eval": eval_train_loader}

    propeller.train.train_and_eval(
        model_class_or_model_fn=MgfModel,
        params=pretrained_model_config,
        run_config=args,
        train_dataset=train_loader,
        eval_dataset=eval_datasets,
        warm_start_setting=ws,
        exporters=[best_exporter],
    )