def train(args, exe, train_prog, agent, train_data_list, epoch_id):
    """Run one supervised training epoch and return the mean batch loss.

    Args:
        args: run configuration; this function reads ``num_tasks``,
            ``batch_size``, ``num_workers``, ``log_interval`` and ``exp``.
        exe: executor used to run ``train_prog``.
        train_prog: program executed once per batch, fetching ``agent.loss``.
        agent: model wrapper exposing ``graph_wrapper`` and ``loss``.
        train_data_list: training samples; its length is reported as the
            epoch total in progress logs.
        epoch_id: epoch number, used only in log messages.

    Returns:
        ``np.mean`` of the per-batch losses collected during the epoch.
    """
    collate = MoleculeCollateFunc(
        agent.graph_wrapper,
        task_type='cls',
        num_cls_tasks=args.num_tasks,
        with_graph_label=True,
        with_pos_neg_mask=False)
    loader = Dataloader(
        train_data_list,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        collate_fn=collate)

    n_total = len(train_data_list)
    n_seen = 0
    batch_losses = []
    for step, batch in enumerate(loader):
        fetched = exe.run(train_prog, feed=batch, fetch_list=[agent.loss])
        step_loss = np.array(fetched).mean()
        batch_losses.append(step_loss)
        # Presumably the number of graphs in this batch; used for [seen/total]
        # progress reporting only — confirm against MoleculeCollateFunc.
        n_seen += batch['graph/num_graph'][0]
        if step % args.log_interval == 0:
            progress = '%s Epoch %d [%d/%d] train/loss:%f' % (
                args.exp, epoch_id, n_seen, n_total, step_loss)
            logging.info(progress)

    epoch_loss = np.mean(batch_losses)
    logging.info('%s Epoch %d train/loss:%f' % (args.exp, epoch_id, epoch_loss))
    sys.stdout.flush()
    return epoch_loss
def evaluate(args, exe, test_prog, agent, test_data_list, epoch_id):
    """Evaluate the model on a test dataset and return its ROC-AUC score.

    Args:
        args: run configuration; this function reads ``num_tasks``,
            ``batch_size`` and ``num_workers``.
        exe: executor used to run ``test_prog``.
        test_prog: inference program; each run fetches ``agent.pred``.
        agent: model wrapper exposing ``graph_wrapper`` and ``pred``.
        test_data_list: evaluation samples.
        epoch_id: not used here; kept so the signature parallels ``train``.

    Returns:
        Whatever ``calc_rocauc_score`` returns for the concatenated labels,
        predictions and validity masks.

    Raises:
        ValueError: if the data loader yields no batches (e.g. an empty
            ``test_data_list``), since there is nothing to score.
    """
    collate_fn = MoleculeCollateFunc(
        agent.graph_wrapper,
        task_type='cls',
        num_cls_tasks=args.num_tasks,
        with_graph_label=True,
        with_pos_neg_mask=False)
    data_loader = Dataloader(
        test_data_list,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        collate_fn=collate_fn)

    total_pred, total_label, total_valid = [], [], []
    for feed_dict in data_loader:
        pred, = exe.run(test_prog, feed=feed_dict,
                        fetch_list=[agent.pred], return_numpy=False)
        # With return_numpy=False the fetched value is presumably a framework
        # tensor rather than an ndarray, hence the explicit conversion.
        total_pred.append(np.array(pred))
        total_label.append(feed_dict['label'])
        total_valid.append(feed_dict['valid'])

    # np.concatenate([]) would fail with an opaque message; fail clearly instead.
    if not total_pred:
        raise ValueError(
            'evaluate: data loader produced no batches for %d samples; '
            'cannot compute ROC-AUC' % len(test_data_list))

    total_pred = np.concatenate(total_pred, 0)
    total_label = np.concatenate(total_label, 0)
    total_valid = np.concatenate(total_valid, 0)
    return calc_rocauc_score(total_label, total_pred, total_valid)
def train(args, exe, train_prog, agent, train_data_list, epoch_id):
    """Run one unsupervised training epoch and return the mean batch loss.

    Batches are collated without graph labels and with positive/negative
    masks, as the unsupervised objective requires.

    Args:
        args: run configuration; this function reads ``batch_size``,
            ``num_workers``, ``log_interval`` and ``is_fleet``.
        exe: executor used to run ``train_prog``.
        train_prog: program executed once per batch, fetching ``agent.loss``.
        agent: model wrapper exposing ``graph_wrapper`` and ``loss``.
        train_data_list: training samples; its length is reported as the
            epoch total in progress logs.
        epoch_id: epoch number, used only in log messages.

    Returns:
        ``np.mean`` of the per-batch losses. This matches the supervised
        trainer; previously the function returned ``None``.
    """
    collate_fn = MoleculeCollateFunc(
        agent.graph_wrapper,
        task_type='cls',
        with_graph_label=False,  # for unsupervised learning
        with_pos_neg_mask=True)
    data_loader = Dataloader(
        train_data_list,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        collate_fn=collate_fn)

    total_data, trained_data = len(train_data_list), 0
    list_loss = []
    for batch_id, feed_dict in enumerate(data_loader):
        train_loss = exe.run(train_prog, feed=feed_dict, fetch_list=[agent.loss])
        train_loss = np.array(train_loss).mean()
        list_loss.append(train_loss)
        # Presumably the number of graphs in this batch; progress reporting only.
        trained_data += feed_dict['graph/num_graph'][0]
        # NOTE(review): per-batch logs are emitted on every fleet worker, while
        # the epoch summary below is restricted to worker 0 — confirm intended.
        if batch_id % args.log_interval == 0:
            logging.info('Epoch %d [%d/%d] train/loss:%f' % \
                         (epoch_id, trained_data, total_data, train_loss))

    epoch_loss = np.mean(list_loss)
    # Only the first worker (or a non-distributed run) logs the epoch summary.
    if not args.is_fleet or fleet.worker_index() == 0:
        logging.info('Epoch %d train/loss:%f' % (epoch_id, epoch_loss))
        sys.stdout.flush()
    return epoch_loss
def save_embedding(args, exe, test_prog, agent, data_list, epoch_id):
    """Compute graph embeddings for ``data_list`` and pickle them with labels.

    The result is written to ``'%s/epoch_%s.pkl' % (args.emb_dir, epoch_id)``
    as a dict ``{'emb': <embeddings>, 'y': <labels>}``. Both are cut to
    ``len(data_list)`` rows.

    Args:
        args: run configuration; this function reads ``batch_size`` and
            ``emb_dir``.
        exe: executor passed through to ``agent.encoder.get_embeddings``.
        test_prog: program passed through to ``agent.encoder.get_embeddings``.
        agent: model wrapper exposing ``graph_wrapper``, ``encoder`` and
            ``graph_emb``.
        data_list: samples to embed, iterated in order (no shuffling).
        epoch_id: used only to name the output file.
    """
    n_samples = len(data_list)
    loader = Dataloader(
        data_list,
        batch_size=args.batch_size,
        num_workers=1,
        shuffle=False,
        collate_fn=MoleculeCollateFunc(
            agent.graph_wrapper,
            task_type='cls',
            with_graph_label=True,  # keep labels so embeddings can be saved alongside y
            with_pos_neg_mask=True))

    embeddings, labels = agent.encoder.get_embeddings(
        loader, exe, test_prog, agent.graph_emb)
    # Trim to the dataset size; presumably get_embeddings can return extra
    # (padded) rows from the final batch — confirm against the encoder.
    payload = {'emb': embeddings[:n_samples], 'y': labels[:n_samples]}

    out_path = '%s/epoch_%s.pkl' % (args.emb_dir, epoch_id)
    with open(out_path, 'wb') as fh:
        pickle.dump(payload, fh)