Example #1
def train(epochs, batchsize, interval, c_path, s_path, modeldir):
    # Dataset definition
    dataset = CRDataset(c_path, s_path)
    collator = CollateFn()

    # Model definition
    generator = CartoonRenderer()
    generator.cuda()
    generator.train()
    gen_opt = torch.optim.Adam(generator.parameters(), lr=0.0001)

    discriminator = Discriminator()
    discriminator.cuda()
    discriminator.train()
    dis_opt = torch.optim.Adam(discriminator.parameters(), lr=0.0001)

    iterations = 0

    for epoch in range(epochs):
        dataloader = DataLoader(dataset,
                                batch_size=batchsize,
                                shuffle=True,
                                drop_last=True,
                                collate_fn=collator)
        dataloader = tqdm(dataloader)

        for i, data in enumerate(dataloader):
            iterations += 1
            c, s = data

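            # Discriminator update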
            y, _, _, _ = generator(c, s)
            dis_loss = adversarial_loss_dis(discriminator, y, s)

            dis_opt.zero_grad()
            dis_loss.backward()
            dis_opt.step()

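            # Generator update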
            y, c_feat, sa_list, y_feat = generator(c, s)
            y_c, _, _, _ = generator(c, c)
            y_s, _, _, _ = generator(s, s)

            gen_loss = adversarial_loss_gen(discriminator, y)
            gen_loss += reconstruction_loss(y_c, c)
            gen_loss += reconstruction_loss(y_s, s)
            gen_loss += content_loss(sa_list, y_feat)
            gen_loss += style_loss(c_feat, y_feat)

            gen_opt.zero_grad()
            gen_loss.backward()
            gen_opt.step()

            if iterations % interval == 1:
                torch.save(generator.state_dict(),
                           f"{modeldir}/model_{iterations}.pt")

            print(
                f"iter: {iterations} dis loss: {dis_loss.item()} gen loss: {gen_loss.item()}"
            )
Example #2
def train(epochs, batchsize, interval, c_path, s_path):
    # Dataset definition
    dataset = HairDataset(c_path, s_path)
    collator = CollateFn()

    # Model & Optimizer Definition
    munit = MUNIT()
    munit.cuda()
    munit.train()
    m_opt = torch.optim.Adam(munit.parameters(),
                             lr=0.0001,
                             betas=(0.5, 0.999),
                             weight_decay=0.0001)

    discriminator_a = Discriminator()
    discriminator_a.cuda()
    discriminator_a.train()
    da_opt = torch.optim.Adam(discriminator_a.parameters(),
                              lr=0.0001,
                              betas=(0.5, 0.999),
                              weight_decay=0.0001)

    discriminator_b = Discriminator()
    discriminator_b.cuda()
    discriminator_b.train()
    db_opt = torch.optim.Adam(discriminator_b.parameters(),
                              lr=0.0001,
                              betas=(0.5, 0.999),
                              weight_decay=0.0001)

    vgg = Vgg19Norm()
    vgg.cuda()
    vgg.train()

    iterations = 0

    for epoch in range(epochs):
        dataloader = DataLoader(dataset,
                                batch_size=batchsize,
                                shuffle=True,
                                drop_last=True,
                                collate_fn=collator)
        dataloader = tqdm(dataloader)

        for i, data in enumerate(dataloader):
            iterations += 1
            a, b = data
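            # Discriminator update: score translated images against reals from each domain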
            _, _, _, _, _, _, ba, ab, _, _, _, _, _, _ = munit(a, b)

            loss = adversarial_dis_loss(discriminator_a, ba, a)
            loss += adversarial_dis_loss(discriminator_b, ab, b)

            da_opt.zero_grad()
            db_opt.zero_grad()
            loss.backward()
            da_opt.step()
            db_opt.step()

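            # Generator update: full forward pass for reconstruction, cycle and perceptual terms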
            c_a, s_a, c_b, s_b, a_recon, \
                b_recon, ba, ab, c_b_recon, s_a_recon, c_a_recon, s_b_recon, aba, bab = munit(a, b)

            loss = adversarial_gen_loss(discriminator_a, ba)
            loss += adversarial_gen_loss(discriminator_b, ab)
            loss += 10 * reconstruction_loss(a_recon, a)
            loss += 10 * reconstruction_loss(b_recon, b)
            loss += reconstruction_loss(c_a, c_a_recon)
            loss += reconstruction_loss(c_b, c_b_recon)
            loss += reconstruction_loss(s_a, s_a_recon)
            loss += reconstruction_loss(s_b, s_b_recon)
            loss += 10 * reconstruction_loss(aba, a)
            loss += 10 * reconstruction_loss(bab, b)
            loss += perceptual_loss(vgg, ba, b)
            loss += perceptual_loss(vgg, ab, a)

            m_opt.zero_grad()
            loss.backward()
            m_opt.step()

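            # Periodically save a checkpoint and visualize a few a-to-b translations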
            if iterations % interval == 1:
                torch.save(munit.state_dict(),
                           f"./modeldir/model_{iterations}.pt")

                pylab.rcParams['figure.figsize'] = (16.0, 16.0)
                pylab.clf()

                munit.eval()

                with torch.no_grad():
                    _, _, _, _, _, _, _, ab, _, _, _, _, _, _ = munit(a, b)
                    fake = ab.detach().cpu().numpy()
                    real = a.detach().cpu().numpy()

                    for i in range(batchsize):
                        tmp = (np.clip(real[i] * 127.5 + 127.5, 0,
                                       255)).transpose(1, 2,
                                                       0).astype(np.uint8)
                        pylab.subplot(4, 4, 2 * i + 1)
                        pylab.imshow(tmp)
                        pylab.axis("off")
                        pylab.savefig(
                            "outdir/visualize_{}.png".format(iterations))
                        tmp = (np.clip(fake[i] * 127.5 + 127.5, 0,
                                       255)).transpose(1, 2,
                                                       0).astype(np.uint8)
                        pylab.subplot(4, 4, 2 * i + 2)
                        pylab.imshow(tmp)
                        pylab.axis("off")
                        pylab.savefig(
                            "outdir/visualize_{}.png".format(iterations))

                munit.train()

            print(f"iter: {iterations} loss: {loss.item()}")
Example #3
# Hyperparameters
num_classes = 10
learning_rate = 1e-4
batch_size = 50
num_epochs = 5

#load dataset
train_dataset = datasets.MNIST(root='dataset/',
                               train=True,
                               transform=GraphTransform(device),
                               download=False)
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          collate_fn=CollateFn(device))
test_dataset = datasets.MNIST(root='dataset/',
                              train=False,
                              transform=GraphTransform(device),
                              download=False)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=True,
                         collate_fn=CollateFn(device))


# To create A + I; please see MNIST.py, this is just a reference
def adj_head(m):
    M = m**2
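
# A hedged sketch only (not the project's MNIST.py): one common way to form
# A + I, i.e. a grid adjacency matrix with self-loops, for an m x m image
# treated as a graph with 4-neighbour connectivity. The helper name and the
# use of NumPy here are illustrative assumptions, not the original code.
def grid_adj_plus_identity(m):
    M = m ** 2                                  # number of nodes (pixels)
    A = np.zeros((M, M), dtype=np.float32)
    for r in range(m):
        for c in range(m):
            idx = r * m + c
            if c + 1 < m:                       # right neighbour
                A[idx, idx + 1] = A[idx + 1, idx] = 1.0
            if r + 1 < m:                       # bottom neighbour
                A[idx, idx + m] = A[idx + m, idx] = 1.0
    return A + np.eye(M, dtype=np.float32)      # A + I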
Example #4
def train(epochs, batchsize, s_interval, c_weight, kl_weight, x_path, y_path):
    generator = Generator()
    generator.cuda()
    generator.train()

    content_discriminator = ContentDiscriminator()
    content_discriminator.cuda()
    content_discriminator.train()

    domain_x_discriminator = DomainDiscriminator()
    domain_x_discriminator.cuda()
    domain_x_discriminator.train()

    domain_y_discriminator = DomainDiscriminator()
    domain_y_discriminator.cuda()
    domain_y_discriminator.train()

    g_optim = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
    cdis_optim = torch.optim.Adam(content_discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))
    ddis_x_optim = torch.optim.Adam(domain_x_discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))
    ddis_y_optim = torch.optim.Adam(domain_y_discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))

    dataset = HairDataset(medium_path=x_path, twin_path=y_path)
    collator = CollateFn()

    iterations = 0

    for epoch in range(epochs):
        dataloader = DataLoader(dataset,
                                batch_size=batchsize,
                                shuffle=True,
                                collate_fn=collator.train,
                                drop_last=True,
                                num_workers=0)
        progress_bar = tqdm(dataloader)

        for index, data in enumerate(progress_bar):
            iterations += 1
            x, y = data

            # discriminator update
            enc_x, enc_y, _, _, fake_x, fake_y, _, _, infers_x, infers_y = generator.forward(x, y)
            _, infer_x, _ = infers_x
            _, infer_y, _ = infers_y
            dis_loss = adversarial_content_D(content_discriminator, enc_x, enc_y)
            dis_loss += adversarial_domain_D(domain_x_discriminator, fake_x, x)
            dis_loss += adversarial_domain_D(domain_y_discriminator, fake_y, y)
            dis_loss += adversarial_domain_D(domain_x_discriminator, infer_x, x)
            dis_loss += adversarial_domain_D(domain_y_discriminator, infer_y, y)

            cdis_optim.zero_grad()
            ddis_x_optim.zero_grad()
            ddis_y_optim.zero_grad()
            dis_loss.backward()
            cdis_optim.step()
            ddis_x_optim.step()
            ddis_y_optim.step()

            # generator update
            enc_x, enc_y, attr_x, attr_y, fake_x, fake_y, recon_x, recon_y, infers_x, infers_y = generator.forward(x, y)
            latent_x, infer_x, infer_attr_x = infers_x
            latent_y, infer_y, infer_attr_y = infers_y
            _, _, _, _, fake_xyx, fake_yxy, _, _, _, _ = generator.forward(fake_x, fake_y)
            gen_loss = adversarial_content_G(content_discriminator, enc_x, enc_y)
            gen_loss += adversarial_domain_G(domain_x_discriminator, fake_x)
            gen_loss += adversarial_domain_G(domain_y_discriminator, fake_y)
            gen_loss += adversarial_domain_G(domain_x_discriminator, infer_x)
            gen_loss += adversarial_domain_G(domain_y_discriminator, infer_y)
            gen_loss += c_weight * cross_cycle_consistency_loss(x, y, fake_xyx, fake_yxy)
            gen_loss += c_weight * cross_cycle_consistency_loss(x, y, recon_x, recon_y)
            gen_loss += c_weight * cross_cycle_consistency_loss(latent_x, latent_y, infer_attr_x, infer_attr_y)
            #gen_loss += kl_weight * (l2_regularize(attr_x) + l2_regularize(attr_y))

            g_optim.zero_grad()
            gen_loss.backward()
            g_optim.step()

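            # Periodically save a checkpoint and dump sample translations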
            if iterations % s_interval == 1:
                torch.save(generator.state_dict(), './model/model_{}.pt'.format(iterations))

                pylab.rcParams['figure.figsize'] = (16.0,16.0)
                pylab.clf()

                with torch.no_grad():
                    _, _, _, _, _, fake_y, _, _, _, _ = generator.forward(x, y)
                    fake_y = fake_y[:2].detach().cpu().numpy()
                    real_x = x[:2].detach().cpu().numpy()

                    for i in range(2):  # two sample pairs were kept above
                        tmp = (np.clip(real_x[i] * 127.5 + 127.5, 0, 255)).transpose(1, 2, 0).astype(np.uint8)
                        pylab.subplot(2, 2, 2 * i + 1)
                        pylab.imshow(tmp)
                        pylab.axis("off")
                        pylab.savefig("outdir/visualize_{}.png".format(iterations))
                        tmp = (np.clip(fake_y[i] * 127.5 + 127.5, 0, 255)).transpose(1, 2, 0).astype(np.uint8)
                        pylab.subplot(2, 2, 2 * i + 2)
                        pylab.imshow(tmp)
                        pylab.axis("off")
                        pylab.savefig("outdir/visualize_{}.png".format(iterations))

            print('iteration: {} dis loss: {} gen loss: {}'.format(iterations, dis_loss.item(), gen_loss.item()))
Example #5
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='GNN baselines on pcqm4m with PGL')
    parser.add_argument('--use_cuda', action='store_true')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument(
        '--gnn',
        type=str,
        default='gin-virtual',
        help=
        'GNN gin, gin-virtual, or gcn, or gcn-virtual (default: gin-virtual)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default='sum',
        help='graph pooling strategy mean or sum (default: sum)')
    parser.add_argument('--drop_ratio',
                        type=float,
                        default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument(
        '--num_layers',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5)')
    parser.add_argument(
        '--emb_dim',
        type=int,
        default=600,
        help='dimensionality of hidden units in GNNs (default: 600)')
    parser.add_argument('--train_subset', action='store_true')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers',
                        type=int,
                        default=1,
                        help='number of workers (default: 1)')
    parser.add_argument('--log_dir',
                        type=str,
                        default="",
                        help='tensorboard log directory')
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='',
                        help='directory to save checkpoint')
    parser.add_argument('--save_test_dir',
                        type=str,
                        default='',
                        help='directory to save test submission file')
    args = parser.parse_args()

    print(args)

    random.seed(42)
    np.random.seed(42)
    paddle.seed(42)

    if not args.use_cuda:
        paddle.set_device("cpu")

    ### automatic dataloading and splitting
    class Config():
        def __init__(self):
            self.base_data_path = "./dataset"

    config = Config()
    ds = MolDataset(config)
    split_idx = ds.get_idx_split()
    test_ds = Subset(ds, split_idx['test'])

    print("Test examples: ", len(test_ds))

    ### automatic evaluator. takes dataset name as input
    evaluator = PCQM4MEvaluator()

    test_loader = Dataloader(test_ds,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             collate_fn=CollateFn())

    shared_params = {
        'num_layers': args.num_layers,
        'emb_dim': args.emb_dim,
        'drop_ratio': args.drop_ratio,
        'graph_pooling': args.graph_pooling
    }

    if args.gnn == 'gin':
        model = GNN(gnn_type='gin', virtual_node=False, **shared_params)
    elif args.gnn == 'gin-virtual':
        model = GNN(gnn_type='gin', virtual_node=True, **shared_params)
    elif args.gnn == 'gcn':
        model = GNN(gnn_type='gcn', virtual_node=False, **shared_params)
    elif args.gnn == 'gcn-virtual':
        model = GNN(gnn_type='gcn', virtual_node=True, **shared_params)
    else:
        raise ValueError('Invalid GNN type')

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}')

    checkpoint_path = os.path.join(args.checkpoint_dir, 'checkpoint.pdparams')
    if not os.path.exists(checkpoint_path):
        raise RuntimeError(f'Checkpoint file not found at {checkpoint_path}')

    model.set_state_dict(paddle.load(checkpoint_path))

    print('Predicting on test data...')
    y_pred = test(model, test_loader)
    print('Saving test submission file...')
    evaluator.save_test_submission({'y_pred': y_pred}, args.save_test_dir)
Example #6
def main():
    utils.writer = SummaryWriter()

    parser = argparse.ArgumentParser()
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='Disables CUDA training.')
    parser.add_argument('--epochs',
                        type=int,
                        default=700,
                        help='Number of epochs to train.')
    parser.add_argument('--link-pred',
                        action='store_true',
                        default=False,
                        help='Enable Link Prediction Loss')
    parser.add_argument('--dataset',
                        default='ENZYMES',
                        help="Choose dataset: ENZYMES, DD")
    parser.add_argument('--batch-size',
                        default=256,
                        type=int,
                        help="Batch size for training")
    parser.add_argument('--train-ratio',
                        default=0.9,
                        type=float,
                        help="Train/Val split ratio")
    parser.add_argument('--pool-ratio',
                        default=0.25,
                        type=float,
                        help="Pooling ratio")

    args = parser.parse_args()
    utils.writer.add_text("args", str(args))
    device = "cuda" if not args.no_cuda and torch.cuda.is_available() else "cpu"

    dataset = TUDataset(args.dataset)
    # dataset = MNIST(root="~/.torch/data/", transform=GraphTransform(device), download=True)
    dataset_size = len(dataset)
    train_size = int(dataset_size * args.train_ratio)
    test_size = dataset_size - train_size
    max_num_nodes = max([item[0][0].shape[0] for item in dataset])
    n_classes = int(max([item[1] for item in dataset])) + 1
    train_data, test_data = random_split(dataset, (train_size, test_size))
    input_shape = int(dataset[0][0][1].shape[-1])
    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              collate_fn=CollateFn(device))
    test_loader = DataLoader(test_data,
                             batch_size=args.batch_size,
                             shuffle=True,
                             collate_fn=CollateFn(device))

    model = BatchedModel(pool_size=int(max_num_nodes * args.pool_ratio),
                         device=device,
                         link_pred=args.link_pred,
                         input_shape=input_shape,
                         n_classes=n_classes).to(device)
    model.train()
    optimizer = optim.Adam(model.parameters())

    for e in tqdm(range(args.epochs)):
        utils.e = e
        epoch_losses_list = []
        true_sample = 0
        model.train()
        for i, (adj, features, masks, batch_labels) in enumerate(train_loader):
            utils.train_iter += 1
            graph_feat = model(features, adj, masks)
            output = model.classifier(graph_feat)
            loss = model.loss(output, batch_labels)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
            optimizer.step()
            optimizer.zero_grad()

            epoch_losses_list.append(loss.item())
            iter_true_sample = (output.argmax(dim=1).long() == batch_labels.long()). \
                float().sum().item()
            iter_acc = float(iter_true_sample) / output.shape[0]
            utils.writer.add_scalar("iter train acc", iter_acc,
                                    utils.train_iter)
            print(f"{utils.train_iter} iter train acc: {iter_acc}")
            true_sample += iter_true_sample

        acc = true_sample / train_size
        utils.writer.add_scalar("Epoch Acc", acc, e)
        tqdm.write(f"Epoch:{e}  \t train_acc:{acc:.2f}")

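        # Evaluate on the held-out split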
        test_loss_list = []
        true_sample = 0
        model.eval()
        with torch.no_grad():
            for i, (adj, features, masks,
                    batch_labels) in enumerate(test_loader):
                utils.test_iter += 1
                graph_feat = model(features, adj, masks)
                output = model.classifier(graph_feat)
                loss = model.loss(output, batch_labels)
                test_loss_list.append(loss.item())
                iter_true_sample = (output.argmax(dim=1).long() == batch_labels.long()). \
                    float().sum().item()
                iter_acc = float(iter_true_sample) / output.shape[0]
                utils.writer.add_scalar("iter test acc", iter_acc,
                                        utils.test_iter)
                print(f"{utils.test_iter} iter test acc: {iter_acc}")
                true_sample += iter_true_sample
        acc = true_sample / test_size
        utils.writer.add_scalar("Epoch Acc", acc, e)
        tqdm.write(f"Epoch:{e}  \t val_acc:{acc:.2f}")
Example #7
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='GNN baselines on pcqm4m with PGL')
    parser.add_argument('--use_cuda', action='store_true')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument(
        '--gnn',
        type=str,
        default='gin-virtual',
        help=
        'GNN gin, gin-virtual, or gcn, or gcn-virtual (default: gin-virtual)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default='sum',
        help='graph pooling strategy mean or sum (default: sum)')
    parser.add_argument('--drop_ratio',
                        type=float,
                        default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument(
        '--num_layers',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5)')
    parser.add_argument(
        '--emb_dim',
        type=int,
        default=600,
        help='dimensionality of hidden units in GNNs (default: 600)')
    parser.add_argument('--train_subset', action='store_true')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers',
                        type=int,
                        default=1,
                        help='number of workers (default: 1)')
    parser.add_argument('--log_dir',
                        type=str,
                        default="",
                        help='tensorboard log directory')
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='',
                        help='directory to save checkpoint')
    parser.add_argument('--save_test_dir',
                        type=str,
                        default='',
                        help='directory to save test submission file')
    args = parser.parse_args()

    print(args)

    random.seed(42)
    np.random.seed(42)
    paddle.seed(42)

    if not args.use_cuda:
        paddle.set_device("cpu")

    ### automatic dataloading and splitting
    class Config():
        def __init__(self):
            self.base_data_path = "./dataset"

    config = Config()
    ds = MolDataset(config)

    split_idx = ds.get_idx_split()
    train_ds = Subset(ds, split_idx['train'])
    valid_ds = Subset(ds, split_idx['valid'])
    test_ds = Subset(ds, split_idx['test'])

    print("Train examples: ", len(train_ds))
    print("Valid examples: ", len(valid_ds))
    print("Test examples: ", len(test_ds))

    ### automatic evaluator. takes dataset name as input
    evaluator = PCQM4MEvaluator()

    train_loader = Dataloader(train_ds,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              collate_fn=CollateFn())

    valid_loader = Dataloader(valid_ds,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              collate_fn=CollateFn())

    if args.save_test_dir != '':
        test_loader = Dataloader(test_ds,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=CollateFn())

    if args.checkpoint_dir != '':
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    shared_params = {
        'num_layers': args.num_layers,
        'emb_dim': args.emb_dim,
        'drop_ratio': args.drop_ratio,
        'graph_pooling': args.graph_pooling
    }

    if args.gnn == 'gin':
        model = GNN(gnn_type='gin', virtual_node=False, **shared_params)
    elif args.gnn == 'gin-virtual':
        model = GNN(gnn_type='gin', virtual_node=True, **shared_params)
    elif args.gnn == 'gcn':
        model = GNN(gnn_type='gcn', virtual_node=False, **shared_params)
    elif args.gnn == 'gcn-virtual':
        model = GNN(gnn_type='gcn', virtual_node=True, **shared_params)
    else:
        raise ValueError('Invalid GNN type')

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}')

    if args.log_dir != '':
        writer = SummaryWriter(log_dir=args.log_dir)

    best_valid_mae = 1000

    scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.001,
                                              step_size=300,
                                              gamma=0.25)

    optimizer = paddle.optimizer.Adam(learning_rate=scheduler,
                                      parameters=model.parameters())

    msg = "ogbg_lsc_paddle_baseline\n"
    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train_mae = train(model, train_loader, optimizer)

        print('Evaluating...')
        valid_mae = eval(model, valid_loader, evaluator)

        print({'Train': train_mae, 'Validation': valid_mae})

        if args.log_dir != '':
            writer.add_scalar('valid/mae', valid_mae, epoch)
            writer.add_scalar('train/mae', train_mae, epoch)

        if valid_mae < best_valid_mae:
            best_valid_mae = valid_mae
            if args.checkpoint_dir != '':
                print('Saving checkpoint...')
                paddle.save(
                    model.state_dict(),
                    os.path.join(args.checkpoint_dir, 'checkpoint.pdparams'))

            if args.save_test_dir != '':
                print('Predicting on test data...')
                y_pred = test(model, test_loader)
                print('Saving test submission file...')
                evaluator.save_test_submission({'y_pred': y_pred},
                                               args.save_test_dir)

        scheduler.step()

        print(f'Best validation MAE so far: {best_valid_mae}')

        try:
            msg +="Epoch: %d | Train: %.6f | Valid: %.6f | Best Valid: %.6f\n" \
                    % (epoch, train_mae, valid_mae, best_valid_mae)
            print(msg)
        except:
            continue

    if args.log_dir != '':
        writer.close()
Example #8
File: main.py Project: Yelrose/PGL
def main(config):
    if dist.get_world_size() > 1:
        dist.init_parallel_env()

    if dist.get_rank() == 0:
        timestamp = datetime.now().strftime("%Hh%Mm%Ss")
        log_path = os.path.join(config.log_dir,
                                "tensorboard_log_%s" % timestamp)
        writer = SummaryWriter(log_path)

    log.info("loading data")
    raw_dataset = GraphPropPredDataset(name=config.dataset_name)
    config.num_class = raw_dataset.num_tasks
    config.eval_metric = raw_dataset.eval_metric
    config.task_type = raw_dataset.task_type

    mol_dataset = MolDataset(config,
                             raw_dataset,
                             transform=make_multihop_edges)
    splitted_index = raw_dataset.get_idx_split()
    train_ds = Subset(mol_dataset, splitted_index['train'], mode='train')
    valid_ds = Subset(mol_dataset, splitted_index['valid'], mode="valid")
    test_ds = Subset(mol_dataset, splitted_index['test'], mode="test")

    log.info("Train Examples: %s" % len(train_ds))
    log.info("Val Examples: %s" % len(valid_ds))
    log.info("Test Examples: %s" % len(test_ds))

    fn = CollateFn(config)

    train_loader = Dataloader(train_ds,
                              batch_size=config.batch_size,
                              shuffle=True,
                              num_workers=config.num_workers,
                              collate_fn=fn)

    valid_loader = Dataloader(valid_ds,
                              batch_size=config.batch_size,
                              num_workers=config.num_workers,
                              collate_fn=fn)

    test_loader = Dataloader(test_ds,
                             batch_size=config.batch_size,
                             num_workers=config.num_workers,
                             collate_fn=fn)

    model = ClassifierNetwork(config.hidden_size, config.out_dim,
                              config.num_layers, config.dropout_prob,
                              config.virt_node, config.K, config.conv_type,
                              config.appnp_hop, config.alpha)
    model = paddle.DataParallel(model)

    optim = Adam(learning_rate=config.lr, parameters=model.parameters())
    criterion = nn.loss.BCEWithLogitsLoss()

    evaluator = Evaluator(config.dataset_name)

    best_valid = 0

    global_step = 0
    for epoch in range(1, config.epochs + 1):
        model.train()
        for idx, batch_data in enumerate(train_loader):
            g, mh_graphs, labels, unmask = batch_data
            g = g.tensor()
            multihop_graphs = []
            for item in mh_graphs:
                multihop_graphs.append(item.tensor())
            g.multi_hop_graphs = multihop_graphs
            labels = paddle.to_tensor(labels)
            unmask = paddle.to_tensor(unmask)

            pred = model(g)
            pred = paddle.masked_select(pred, unmask)
            labels = paddle.masked_select(labels, unmask)
            train_loss = criterion(pred, labels)
            train_loss.backward()
            optim.step()
            optim.clear_grad()

            if global_step % 80 == 0:
                message = "train: epoch %d | step %d | " % (epoch, global_step)
                message += "loss %.6f" % (train_loss.numpy())
                log.info(message)
                if dist.get_rank() == 0:
                    writer.add_scalar("loss", train_loss.numpy(), global_step)
            global_step += 1

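        # Per-epoch evaluation on the validation and test splits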
        valid_result = evaluate(model, valid_loader, criterion, evaluator)
        message = "valid: epoch %d | step %d | " % (epoch, global_step)
        for key, value in valid_result.items():
            message += " | %s %.6f" % (key, value)
            if dist.get_rank() == 0:
                writer.add_scalar("valid_%s" % key, value, global_step)
        log.info(message)

        test_result = evaluate(model, test_loader, criterion, evaluator)
        message = "test: epoch %d | step %d | " % (epoch, global_step)
        for key, value in test_result.items():
            message += " | %s %.6f" % (key, value)
            if dist.get_rank() == 0:
                writer.add_scalar("test_%s" % key, value, global_step)
        log.info(message)

        if best_valid < valid_result[config.metrics]:
            best_valid = valid_result[config.metrics]
            best_valid_result = valid_result
            best_test_result = test_result

        message = "best result: epoch %d | " % (epoch)
        message += "valid %s: %.6f | " % (config.metrics,
                                          best_valid_result[config.metrics])
        message += "test %s: %.6f | " % (config.metrics,
                                         best_test_result[config.metrics])
        log.info(message)

    message = "final eval best result:%.6f" % best_valid_result[config.metrics]
    log.info(message)
    message = "final test best result:%.6f" % best_test_result[config.metrics]
    log.info(message)