Example #1
0
parser.add_argument('--weight_decay', type=float, default=5e-4)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--epochs', type=int, default=300)
parser.add_argument('--seed', type=int, default=1)
args = parser.parse_args()

# Data directory: <directory of this script>/../data/<args.dataset>
args.data_fp = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
                        args.dataset)
device = torch.device('cuda', args.device_idx)

# deterministic
torch.manual_seed(args.seed)
cudnn.benchmark = False
cudnn.deterministic = True

train_dataset = MNISTSuperpixels(args.data_fp, True, pre_transform=T.Polar())
test_dataset = MNISTSuperpixels(args.data_fp, False, pre_transform=T.Polar())
# Use the parsed --batch_size (default 64) rather than a hard-coded value,
# so the command-line flag actually takes effect.
train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                          shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size)


def normalized_cut_2d(edge_index, pos):
    """Return normalized-cut edge weights derived from Euclidean edge lengths.

    Each edge's length is the L2 distance between the positions of its two
    endpoint nodes. Those lengths are passed to ``normalized_cut`` together
    with the node count taken from ``pos``.
    """
    src, dst = edge_index
    lengths = (pos[src] - pos[dst]).norm(p=2, dim=1)
    return normalized_cut(edge_index, lengths, num_nodes=pos.size(0))


class MoNet(torch.nn.Module):
    def __init__(self, kernel_size):
        super(MoNet, self).__init__()
        self.conv1 = GMMConv(1, 32, dim=2, kernel_size=kernel_size)
Example #2
0
    # SuperPixels dataset
    sp_path = os.path.join(os.path.dirname(os.path.realpath("/")),
                           "MNISTSuperpixel")
    sp_test_dataset = MNISTSuperpixels(sp_path,
                                       train=False,
                                       transform=T.Cartesian())
    sp_train_loader_25, sp_val_loader_25 = get_train_val_loader(
        sp_path, train_batch_size=25, val_batch_size=25)
    sp_train_loader_64, sp_val_loader_64 = get_train_val_loader(
        sp_path, train_batch_size=64, val_batch_size=64)
    sp_test_loader = DataLoader(sp_test_dataset, batch_size=1, shuffle=False)

    # Skeletons dataset
    sk_path = "dataset"
    sk_train_dataset = MNISTSkeleton(sk_path, "train", transform=T.Polar())
    sk_test_dataset = MNISTSkeleton(sk_path, "test", transform=T.Polar())
    sk_val_dataset = MNISTSkeleton(sk_path, "val", transform=T.Polar())

    sk_train_loader = DataLoader(sk_train_dataset, batch_size=64, shuffle=True)
    sk_val_loader = DataLoader(sk_val_dataset, batch_size=64, shuffle=False)
    sk_test_loader = DataLoader(sk_test_dataset, batch_size=64, shuffle=False)

    # Train
    num_epochs = 150
    print('MoNet Superpixel starts:')
    process_model(MoNet,
                  'MoNet_SuperPixels.txt',
                  sp_train_loader_25,
                  sp_val_loader_25,
                  sp_test_loader,
Example #3
0
def main(args):
    """Train and evaluate a MoNet classifier on MNIST graph data.

    Loaders are built for ``args.dataset``:
      * 'sp' -- MNISTSuperpixels with a Cartesian or polar pre-transform,
        batched by ``tgDataLoader``;
      * 'sm' -- MNISTGraphDataset batched by a plain ``DataLoader``; each
        batch is converted with ``tg_transform`` before reaching the model.

    The model ``C`` (loaded from disk when ``args.load_model`` is set,
    otherwise a fresh ``MoNet``) is trained with Adam for epochs
    ``args.start_epoch`` .. ``args.num_epochs - 1``. The test set is evaluated
    before every epoch and once more at the end. Each evaluation is appended
    to ``args.out_path + args.name + '.txt'``. Every 10 epochs the whole model
    is saved and a loss plot is written.

    Args:
        args: parsed command-line namespace; it is first passed through
            ``init_dirs``.

    Raises:
        ValueError: if ``args.dataset`` is neither 'sp' nor 'sm'.
    """
    args = init_dirs(args)

    pt = T.Cartesian() if args.cartesian else T.Polar()

    if args.dataset == 'sp':
        train_dataset = MNISTSuperpixels(args.dataset_path, True, pre_transform=pt)
        test_dataset = MNISTSuperpixels(args.dataset_path, False, pre_transform=pt)
        train_loader = tgDataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        test_loader = tgDataLoader(test_dataset, batch_size=args.batch_size)
    elif args.dataset == 'sm':
        train_dataset = MNISTGraphDataset(args.dataset_path, args.num_hits, train=True)
        train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, pin_memory=True)
        test_dataset = MNISTGraphDataset(args.dataset_path, args.num_hits, train=False)
        # Summed loss and accuracy do not depend on order, so do not shuffle at test time.
        test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size, pin_memory=True)
    else:
        # Fail fast: otherwise the loaders stay unbound and training crashes later.
        raise ValueError("Unknown dataset {!r}; expected 'sp' or 'sm'".format(args.dataset))

    if args.load_model:
        # Loads a whole pickled module (matching save_model below).
        # NOTE(review): newer torch releases default torch.load to weights_only=True,
        # which rejects full modules -- confirm against the installed torch version.
        C = torch.load(args.model_path + args.name + "/C_" + str(args.start_epoch) + ".pt").to(device)
    else:
        C = MoNet(args.kernel_size).to(device)

    C_optimizer = torch.optim.Adam(C.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    if args.scheduler:
        C_scheduler = torch.optim.lr_scheduler.StepLR(C_optimizer, args.decay_step, gamma=args.lr_decay)

    train_losses = []
    test_losses = []

    def plot_losses(epoch, train_losses, test_losses):
        """Save training/testing loss curves side by side as ``<epoch>.png``."""
        fig = plt.figure()
        ax1 = fig.add_subplot(1, 2, 1)
        ax1.plot(train_losses)
        ax1.set_title('training')
        ax2 = fig.add_subplot(1, 2, 2)
        ax2.plot(test_losses)
        ax2.set_title('testing')

        plt.savefig(args.losses_path + args.name + "/" + str(epoch) + ".png")
        plt.close()

    def save_model(epoch):
        """Pickle the whole model to ``<model_path><name>/C_<epoch>.pt``."""
        torch.save(C, args.model_path + args.name + "/C_" + str(epoch) + ".pt")

    def train_C(data, y):
        """Run one optimisation step on a batch and return its loss as a float."""
        C.train()
        C_optimizer.zero_grad()

        output = C(data)

        # nll_loss takes class labels as target, so one-hot encoding is not needed
        C_loss = F.nll_loss(output, y)

        C_loss.backward()
        C_optimizer.step()

        return C_loss.item()

    def test(epoch):
        """Evaluate on the test set, record the mean loss and log the accuracy."""
        C.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data in test_loader:
                if args.dataset == 'sp':
                    output = C(data.to(device))
                    y = data.y.to(device)
                else:  # 'sm' (validated above)
                    output = C(tg_transform(args, data[0].to(device)))
                    y = data[1].to(device)

                # reduction='sum' replaces the deprecated size_average=False.
                test_loss += F.nll_loss(output, y, reduction='sum').item()
                pred = output.argmax(dim=1)
                # Accumulate a Python int rather than a device tensor.
                correct += pred.eq(y.view_as(pred)).sum().item()

        n_test = len(test_loader.dataset)
        test_loss /= n_test
        test_losses.append(test_loss)

        print('test')

        out_file = args.out_path + args.name + '.txt'
        print(out_file)
        s = "After {} epochs, on test set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            epoch, test_loss, correct, n_test, 100. * correct / n_test)
        print(s)
        with open(out_file, 'a') as f:
            f.write(s)

    for i in range(args.start_epoch, args.num_epochs):
        print("Epoch %d %s" % ((i + 1), args.name))
        C_loss = 0
        # Evaluate the model as it stands after `i` epochs, before training epoch i + 1.
        test(i)
        for data in tqdm(train_loader, total=len(train_loader)):
            if args.dataset == 'sp':
                C_loss += train_C(data.to(device), data.y.to(device))
            else:  # 'sm'
                C_loss += train_C(tg_transform(args, data[0].to(device)), data[1].to(device))

        train_losses.append(C_loss / len(train_loader))

        if args.scheduler:
            C_scheduler.step()

        if (i + 1) % 10 == 0:
            save_model(i + 1)
            plot_losses(i + 1, train_losses, test_losses)

    test(args.num_epochs)
Example #4
0
 def prepare_data(self):
     """Create the train and test MNISTSuperpixels datasets under ``args.data_fp``.

     Each split gets its own pre-transform pipeline: ``remove_self_loops``
     followed by ``T.Polar()``. ``args`` is resolved from the enclosing
     scope, not from ``self``.

     NOTE(review): torch_geometric.utils.remove_self_loops takes
     (edge_index, edge_attr) rather than a Data object; confirm that the name
     here refers to a Data-level transform, otherwise the pre-transform fails.
     """
     def build_pre_transform():
         return T.Compose([remove_self_loops, T.Polar()])

     root = args.data_fp
     self.train_dataset = MNISTSuperpixels(root, True, pre_transform=build_pre_transform())
     self.test_dataset = MNISTSuperpixels(root, False, pre_transform=build_pre_transform())
def graph_transform_1():
    """Return a transform pipeline: Distance, then Cartesian, then Polar.

    All three edge-feature transforms are constructed with ``norm=False`` and
    ``cat=True``. Presumably this means raw values appended to any existing
    edge attributes; see the torch_geometric.transforms documentation.
    """
    shared_kwargs = dict(norm=False, cat=True)
    steps = [
        transforms_G.Distance(**shared_kwargs),
        transforms_G.Cartesian(**shared_kwargs),
        transforms_G.Polar(**shared_kwargs),
    ]
    return transforms_G.Compose(steps)
Example #6
0
def load_dataset(args):
    """Load a graph dataset by name and split it into train/validation/test lists.

    Supported values of ``args.dataset``:
      * 'mnist' -- MNISTSuperpixels with ``T.Polar()``; no validation split;
      * 'QM9'   -- shuffled QM9, split at fixed indices 110000 / 120000;
      * 'zinc'  -- the ZINC subset with its train/val/test splits;
      * 'ogbg-molhiv', 'ogbg-molpcba', 'ogbg-ppa', 'ogbg-code2' -- OGB
        graph-property datasets using ``get_idx_split()``.

    Args:
        args: namespace providing ``dataset``. QM9 also reads ``target``, and
            ogbg-code2 also reads ``max_seq_len`` and ``num_vocab``.

    Returns:
        ``(dataset, train_data, validation_data, test_data, cls_criterion,
        idx2word_mapper)``. The three splits are lists. ``idx2word_mapper``
        is None except for ogbg-code2.

    Raises:
        ValueError: if ``args.dataset`` is not a supported name.
    """
    ogb_datasets = ('ogbg-molhiv', 'ogbg-molpcba', 'ogbg-ppa', 'ogbg-code2')
    if args.dataset not in ('mnist', 'QM9', 'zinc') + ogb_datasets:
        # Fail fast with a clear message instead of falling through every branch.
        raise ValueError(f'Unsupported dataset: {args.dataset!r}')

    # automatic data loading and splitting
    transform = add_zeros if args.dataset == 'ogbg-ppa' else None
    cls_criterion = get_loss_function(args.dataset)
    idx2word_mapper = None

    if args.dataset == 'mnist':
        train_data = MNISTSuperpixels(root='dataset',
                                      train=True,
                                      transform=T.Polar())
        dataset = train_data
        dataset.name = 'mnist'
        dataset.eval_metric = 'acc'
        # No validation split is built for MNIST.
        validation_data = []
        test_data = MNISTSuperpixels(root='dataset',
                                     train=False,
                                     transform=T.Polar())

        train_data = list(train_data)
        test_data = list(test_data)

    elif args.dataset == 'QM9':
        # Contains 19 targets. Use only the first 12 (0-11)
        QM9_VALIDATION_START = 110000
        QM9_VALIDATION_END = 120000
        dataset = QM9(root='dataset',
                      transform=ExtractTargetTransform(args.target)).shuffle()
        dataset.name = 'QM9'
        dataset.eval_metric = 'mae'

        train_data = list(dataset[:QM9_VALIDATION_START])
        validation_data = list(dataset[QM9_VALIDATION_START:QM9_VALIDATION_END])
        test_data = list(dataset[QM9_VALIDATION_END:])

    elif args.dataset == 'zinc':
        train_data = ZINC(root='dataset', subset=True, split='train')

        dataset = train_data
        dataset.name = 'zinc'
        validation_data = ZINC(root='dataset', subset=True, split='val')
        test_data = ZINC(root='dataset', subset=True, split='test')
        dataset.eval_metric = 'mae'

        train_data = list(train_data)
        validation_data = list(validation_data)
        test_data = list(test_data)

    elif args.dataset in ogb_datasets:
        dataset = PygGraphPropPredDataset(name=args.dataset,
                                          transform=transform)
        split_idx = dataset.get_idx_split()

        if args.dataset == 'ogbg-code2':
            # Report the share of target sequences that fit within max_seq_len.
            max_seq_len = args.max_seq_len
            seq_len_list = np.array([len(seq) for seq in dataset.data.y])
            frac_within_max = np.mean(seq_len_list <= max_seq_len)
            print(
                f'Target sequence less or equal to {max_seq_len} is {frac_within_max:.2%}.'
            )

            # The vocabulary is only used in the evaluation of the ogbg-code classifier.
            vocab2idx, idx2vocab = get_vocab_mapping(
                [dataset.data.y[i] for i in split_idx['train']],
                args.num_vocab)
            # specific transformations for the ogbg-code dataset
            dataset.transform = transforms.Compose([
                augment_edge,
                lambda data: encode_y_to_arr(data, vocab2idx, args.max_seq_len)
            ])
            idx2word_mapper = partial(decode_arr_to_seq, idx2vocab=idx2vocab)

        train_data = list(dataset[split_idx["train"]])
        validation_data = list(dataset[split_idx["valid"]])
        test_data = list(dataset[split_idx["test"]])

    return dataset, train_data, validation_data, test_data, cls_criterion, idx2word_mapper