Example #1
def test_get_prototypes():
    # Numpy
    num_classes = 3
    embeddings_np = np.random.rand(2, 5, 7).astype(np.float32)
    targets_np = np.random.randint(0, num_classes, size=(2, 5))

    # PyTorch
    embeddings_th = torch.as_tensor(embeddings_np)
    targets_th = torch.as_tensor(targets_np)
    prototypes_th = get_prototypes(embeddings_th, targets_th, num_classes)

    assert prototypes_th.shape == (2, num_classes, 7)
    assert prototypes_th.dtype == embeddings_th.dtype

    prototypes_np = np.zeros((2, num_classes, 7), dtype=np.float32)
    num_samples_np = np.zeros((2, num_classes), dtype=np.int_)
    for i in range(2):
        for j in range(5):
            num_samples_np[i, targets_np[i, j]] += 1
            for k in range(7):
                prototypes_np[i, targets_np[i, j], k] += embeddings_np[i, j, k]

    for i in range(2):
        for j in range(num_classes):
            for k in range(7):
                prototypes_np[i, j, k] /= max(num_samples_np[i, j], 1)

    np.testing.assert_allclose(prototypes_th.detach().numpy(), prototypes_np)
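
For reference, the loop-based NumPy check above can be written more compactly with np.add.at; a minimal sketch assuming the same array shapes and names as in the test (this helper is not part of the original test):

import numpy as np

def prototypes_reference(embeddings_np, targets_np, num_classes):
    # embeddings_np: (batch, num_samples, dim); targets_np: (batch, num_samples)
    batch, _, dim = embeddings_np.shape
    prototypes = np.zeros((batch, num_classes, dim), dtype=embeddings_np.dtype)
    counts = np.zeros((batch, num_classes), dtype=np.int_)
    for i in range(batch):
        # Unbuffered scatter-add of each embedding into its class slot
        np.add.at(prototypes[i], targets_np[i], embeddings_np[i])
        np.add.at(counts[i], targets_np[i], 1)
    # Average, guarding against classes with no samples
    return prototypes / np.maximum(counts, 1)[..., None]
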
Example #2
def rep_memory(args, model, memory_train):
    memory_loss = 0
    for dataidx, dataloader_dict in enumerate(memory_train):
        for dataname, memory_list in dataloader_dict.items():
            select = random.choice(memory_list)
            memory_train_inputs, memory_train_targets = select['train']
            memory_train_inputs = memory_train_inputs.to(device=args.device)
            memory_train_targets = memory_train_targets.to(device=args.device)
            if memory_train_inputs.size(2) == 1:
                memory_train_inputs = memory_train_inputs.repeat(1, 1, 3, 1, 1)
            memory_train_embeddings = model(memory_train_inputs, dataidx)

            memory_test_inputs, memory_test_targets = select['test']
            memory_test_inputs = memory_test_inputs.to(device=args.device)
            memory_test_targets = memory_test_targets.to(device=args.device)
            if memory_test_inputs.size(2) == 1:
                memory_test_inputs = memory_test_inputs.repeat(1, 1, 3, 1, 1)

            memory_test_embeddings = model(memory_test_inputs, dataidx)
            memory_prototypes = get_prototypes(memory_train_embeddings,
                                               memory_train_targets,
                                               args.num_way)
            memory_loss += prototypical_loss(memory_prototypes,
                                             memory_test_embeddings,
                                             memory_test_targets)

    return memory_loss
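
A hypothetical call site for rep_memory, assuming memory_train is a list of {dataset_name: [episode, ...]} dictionaries as iterated above; the replay weighting is an assumption, not taken from the original code:

# Hypothetical usage; `args`, `model`, `loss` and the replay buffer
# `memory_train` are assumed to already exist as in the snippet above.
replay_weight = 1.0  # assumed hyperparameter
memory_loss = rep_memory(args, model, memory_train)
total_loss = loss + replay_weight * memory_loss
total_loss.backward()
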
Example #3
def test(device, testset, testloader, model):
    model.to(device)
    model.eval()
    acc = []
    with tqdm(testloader, total=config.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=device)
            train_targets = train_targets.to(device=device)

            train_embeddings = model(train_inputs)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=device)
            test_targets = test_targets.to(device=device)

            test_embeddings = model(test_inputs)

            prototypes = get_prototypes(train_embeddings, train_targets,
                                        testset.num_classes_per_task)

            with torch.no_grad():
                accuracy = get_accuracy(prototypes, test_embeddings,
                                        test_targets)
                pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
                acc.append(accuracy)

            if batch_idx >= config.num_batches:
                break
    return acc
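
The returned list can be reduced to a single score afterwards; a minimal sketch, assuming acc is the list of 0-dim accuracy tensors collected above:

import torch

acc_tensor = torch.stack(acc)  # list of 0-dim tensors -> 1-D tensor
print('mean accuracy: {0:.4f} +/- {1:.4f}'.format(acc_tensor.mean().item(),
                                                  acc_tensor.std().item()))
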
Example #4
def train(args):
    logger.warning('This script is an example to showcase the extensions and '
                   'data-loading features of Torchmeta, and as such has been '
                   'very lightly tested.')

    dataset = omniglot(args.folder,
                       shots=args.num_shots,
                       ways=args.num_ways,
                       shuffle=True,
                       test_shots=15,
                       meta_train=True,
                       download=args.download)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model = PrototypicalNetwork(1,
                                args.embedding_size,
                                hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)
            train_embeddings = model(train_inputs)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)
            test_embeddings = model(test_inputs)

            prototypes = get_prototypes(train_embeddings, train_targets,
                                        dataset.num_classes_per_task)
            loss = prototypical_loss(prototypes, test_embeddings, test_targets)

            loss.backward()
            optimizer.step()

            with torch.no_grad():
                accuracy = get_accuracy(prototypes, test_embeddings, test_targets)
                pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))

            if batch_idx >= args.num_batches:
                break

    # Save model
    if args.output_folder is not None:
        filename = os.path.join(
            args.output_folder,
            'protonet_omniglot_{0}shot_{1}way.pt'.format(args.num_shots,
                                                         args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
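
A minimal, hypothetical driver for train(args); the argument names simply mirror the attributes accessed in the function above, and the defaults are assumptions:

import argparse
import torch

parser = argparse.ArgumentParser('Prototypical Networks on Omniglot')
parser.add_argument('folder', type=str)
parser.add_argument('--num-shots', type=int, default=5)
parser.add_argument('--num-ways', type=int, default=5)
parser.add_argument('--embedding-size', type=int, default=64)
parser.add_argument('--hidden-size', type=int, default=64)
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--num-batches', type=int, default=100)
parser.add_argument('--num-workers', type=int, default=1)
parser.add_argument('--output-folder', type=str, default=None)
parser.add_argument('--download', action='store_true')
args = parser.parse_args()
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train(args)
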
Example #5
def train(device, dataset, dataloader, model):
    print("in train")
    model = model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    # Training loop
    images_per_batch = {'train': [], 'test': []}
    batch_count = 0
    with tqdm(dataloader, total=config.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=device)
            train_targets = train_targets.to(device=device)
            train_embeddings = model(train_inputs)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=device)
            test_targets = test_targets.to(device=device)
            test_embeddings = model(test_inputs)

            prototypes = get_prototypes(train_embeddings, train_targets,
                                        dataset.num_classes_per_task)
            loss = prototypical_loss(prototypes, test_embeddings, test_targets)

            loss.backward()
            optimizer.step()

            # Just keeping the count here
            batch_count += 1
            images_per_batch['train'].append(train_inputs.shape[1])
            images_per_batch['test'].append(test_inputs.shape[1])

            with torch.no_grad():
                accuracy = get_accuracy(prototypes, test_embeddings,
                                        test_targets)
                pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))

            if batch_idx >= config.num_batches:
                break

    print("Number of batches in the dataloader: ", batch_count)

    # Save model
    if check_dir() is not None:
        filename = os.path.join(
            'saved_models',
            'protonet_cifar_fs_{0}shot_{1}way.pt'.format(config.k, config.n))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
            print("Model saved")

    return batch_count, images_per_batch
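
check_dir is not defined in this snippet; a plausible stand-in that creates the saved_models directory and returns its path (an assumption, not the original helper):

import os

def check_dir(path='saved_models'):
    # Ensure the checkpoint directory exists and return it
    os.makedirs(path, exist_ok=True)
    return path
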
Example #6
    def valid(self, dataloader_dict, domain_id, epoch):
        self.model.eval()
        acc_dict = {}
        acc_list = []
        for dataname, dataloader in dataloader_dict.items():
            with torch.no_grad():
                with tqdm(dataloader,
                          total=self.args.num_valid_batches) as pbar:
                    for batch_idx, batch in enumerate(pbar):
                        self.model.zero_grad()
                        train_inputs, train_targets = batch['train']
                        train_inputs = train_inputs.to(device=self.args.device)
                        train_targets = train_targets.to(
                            device=self.args.device)
                        if train_inputs.size(2) == 1:
                            train_inputs = train_inputs.repeat(1, 1, 3, 1, 1)

                        train_embeddings = self.model(train_inputs, domain_id)
                        test_inputs, test_targets = batch['test']
                        test_inputs = test_inputs.to(device=self.args.device)
                        test_targets = test_targets.to(device=self.args.device)
                        if test_inputs.size(2) == 1:
                            test_inputs = test_inputs.repeat(1, 1, 3, 1, 1)
                        test_embeddings = self.model(test_inputs, domain_id)

                        prototypes = get_prototypes(train_embeddings,
                                                    train_targets,
                                                    self.args.num_way)
                        accuracy = get_accuracy(prototypes, test_embeddings,
                                                test_targets)
                        acc_list.append(accuracy.cpu().data.numpy())
                        pbar.set_description(
                            'dataname {} accuracy ={:.4f}'.format(
                                dataname, np.mean(acc_list)))
                        if batch_idx >= self.args.num_valid_batches:
                            break

            avg_accuracy = np.round(np.mean(acc_list), 4)
            # Record the average accuracy for this dataset
            acc_dict[dataname] = avg_accuracy

        return acc_dict
Example #7
    def train(self, epoch, dataloader_dict, memory_train=None, domain_id=None):
        self.model.train()
        for dataname, dataloader in dataloader_dict.items():
            with tqdm(dataloader, total=self.args.num_batches) as pbar:
                for batch_idx, batch in enumerate(pbar):
                    self.model.zero_grad()
                    train_inputs, train_targets = batch['train']
                    train_inputs = train_inputs.to(device=self.args.device)
                    train_targets = train_targets.to(device=self.args.device)
                    if train_inputs.size(2) == 1:
                        train_inputs = train_inputs.repeat(1, 1, 3, 1, 1)
                    train_embeddings = self.model(train_inputs, domain_id)
                    test_inputs, test_targets = batch['test']
                    test_inputs = test_inputs.to(device=self.args.device)
                    test_targets = test_targets.to(device=self.args.device)
                    if test_inputs.size(2) == 1:
                        test_inputs = test_inputs.repeat(1, 1, 3, 1, 1)
                    test_embeddings = self.model(test_inputs, domain_id)

                    prototypes = get_prototypes(train_embeddings,
                                                train_targets,
                                                self.args.num_way)
                    loss = prototypical_loss(prototypes, test_embeddings,
                                             test_targets)
                    loss.backward(retain_graph=True)

                    param_list = []
                    param_names = []
                    for name, v in self.model.named_parameters():
                        if 'domain_out' not in name:
                            if v.requires_grad:
                                param_list.append(v)
                                param_names.append(name)
                    first_grad = torch.autograd.grad(loss,
                                                     param_list,
                                                     create_graph=False,
                                                     retain_graph=False)

                    val_graddict = {}
                    layer_name = []
                    for gradient, name in zip(first_grad, param_names):
                        split_name = name.split('.')
                        layer = split_name[0]
                        if layer not in self.args.layer_filters:
                            if layer not in layer_name:
                                layer_name.append(layer)
                                val_graddict[layer] = []
                                val_graddict[layer].append(
                                    gradient.clone().view(-1))
                            else:
                                val_graddict[layer].append(
                                    gradient.clone().view(-1))
                        else:
                            layer_sub = layer + '.' + split_name[
                                1] + '.' + split_name[2]
                            if layer_sub not in layer_name:
                                layer_name.append(layer_sub)
                                val_graddict[layer_sub] = []
                                val_graddict[layer_sub].append(
                                    gradient.clone().view(-1))
                            else:
                                val_graddict[layer_sub].append(
                                    gradient.clone().view(-1))

                    for key in val_graddict:
                        val_graddict[key] = torch.cat(val_graddict[key])
                    self.optimizer.step()

                    if memory_train:
                        memory_trainnew = copy.deepcopy(memory_train)
                        self.hyper_optim.optimizer = self.optimizer
                        self.hyper_optim.compute_hg(self.model, val_graddict)
                        val_grad = self.rep_grad_new(self.args,
                                                     memory_trainnew)

                        self.hyper_optim.hyper_step(val_grad)
                        self.model.zero_grad()

                    if batch_idx >= self.args.num_batches:
                        break
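
The per-layer gradient grouping above is repeated almost verbatim in rep_grad_new (Example #8); a hedged refactoring sketch of that logic as a standalone helper (hypothetical, not part of the original class):

import torch

def group_grads_by_layer(grads, names, layer_filters):
    # Concatenate flattened gradients, keyed by the top-level layer name, or
    # by layer.submodule.index when the layer appears in layer_filters.
    grouped = {}
    for grad, name in zip(grads, names):
        parts = name.split('.')
        key = parts[0] if parts[0] not in layer_filters else '.'.join(parts[:3])
        grouped.setdefault(key, []).append(grad.reshape(-1))
    return {key: torch.cat(chunks) for key, chunks in grouped.items()}
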
Example #8
    def rep_grad_new(self, args, memory_train):
        memory_loss = 0
        for dataidx, dataloader_dict in enumerate(memory_train):
            for dataname, memory_list in dataloader_dict.items():
                select = random.choice(memory_list)
                memory_train_inputs, memory_train_targets = select['train']
                memory_train_inputs = memory_train_inputs.to(
                    device=args.device)
                memory_train_targets = memory_train_targets.to(
                    device=args.device)
                if memory_train_inputs.size(2) == 1:
                    memory_train_inputs = memory_train_inputs.repeat(
                        1, 1, 3, 1, 1)
                memory_train_embeddings = self.model(memory_train_inputs,
                                                     dataidx)

                memory_test_inputs, memory_test_targets = select['test']
                memory_test_inputs = memory_test_inputs.to(device=args.device)
                memory_test_targets = memory_test_targets.to(
                    device=args.device)
                if memory_test_inputs.size(2) == 1:
                    memory_test_inputs = memory_test_inputs.repeat(
                        1, 1, 3, 1, 1)

            indlist = []
            for ind in range(len(memory_train)):
                if ind != dataidx:
                    indlist.append(ind)

            if indlist:
                indselect = random.choice(indlist)
                dataloader_dict2 = memory_train[indselect]
                for dataname, memory_list in dataloader_dict2.items():
                    select2 = random.choice(memory_list)
                    memory_train_inputs2, memory_train_targets2 = select2[
                        'train']
                    memory_train_inputs2 = memory_train_inputs2.to(
                        device=args.device)
                    memory_train_targets2 = memory_train_targets2.to(
                        device=args.device)
                    if memory_train_inputs2.size(2) == 1:
                        memory_train_inputs2 = memory_train_inputs2.repeat(
                            1, 1, 3, 1, 1)
                    memory_train_embeddings2 = self.model(
                        memory_train_inputs2, dataidx)

                    memory_test_inputs2, memory_test_targets2 = select2['test']
                    memory_test_inputs2 = memory_test_inputs2.to(
                        device=args.device)
                    memory_test_targets2 = memory_test_targets2.to(
                        device=args.device)
                    if memory_test_inputs2.size(2) == 1:
                        memory_test_inputs2 = memory_test_inputs2.repeat(
                            1, 1, 3, 1, 1)

                memory_test_embeddings2 = self.model(memory_test_inputs2,
                                                     dataidx)
            memory_test_embeddings = self.model(memory_test_inputs, dataidx)
            memory_prototypes = get_prototypes(memory_train_embeddings,
                                               memory_train_targets,
                                               args.num_way)
            memory_loss += prototypical_loss(memory_prototypes,
                                             memory_test_embeddings,
                                             memory_test_targets)

            if indlist:
                sellist = range(len(memory_train_embeddings))
                i = random.choice(sellist)
                memory_train_embeddings = torch.cat([
                    memory_train_embeddings[i].unsqueeze(0),
                    memory_train_embeddings2[i].unsqueeze(0)
                ], dim=1)
                memory_train_targets = torch.cat([
                    memory_train_targets[i].unsqueeze(0),
                    memory_train_targets2[i].unsqueeze(0) + args.num_way
                ], dim=1)
                memory_prototypes = get_prototypes(memory_train_embeddings,
                                                   memory_train_targets,
                                                   2 * args.num_way)
                memory_test_embeddings = torch.cat([
                    memory_test_embeddings[i].unsqueeze(0),
                    memory_test_embeddings2[i].unsqueeze(0)
                ], dim=1)
                memory_test_targets = torch.cat([
                    memory_test_targets[i].unsqueeze(0),
                    memory_test_targets2[i].unsqueeze(0) + args.num_way
                ], dim=1)
                memory_loss += 1e-4 * prototypical_loss(
                    memory_prototypes, memory_test_embeddings,
                    memory_test_targets)

        param_list = []
        param_names = []
        for name, v in self.model.named_parameters():
            if 'domain_out' not in name:
                if v.requires_grad:
                    param_list.append(v)
                    param_names.append(name)
        val_grad = torch.autograd.grad(memory_loss, param_list)

        val_graddict = {}
        layer_name = []
        for gradient, name in zip(val_grad, param_names):
            split_name = name.split('.')
            layer = split_name[0]
            if layer not in self.args.layer_filters:
                if layer not in layer_name:
                    layer_name.append(layer)
                    val_graddict[layer] = []
                    val_graddict[layer].append(gradient.view(-1))
                else:
                    val_graddict[layer].append(gradient.view(-1))
            else:
                layer_sub = layer + '.' + split_name[1] + '.' + split_name[2]
                if layer_sub not in layer_name:
                    layer_name.append(layer_sub)
                    val_graddict[layer_sub] = []
                    val_graddict[layer_sub].append(gradient.view(-1))
                else:
                    val_graddict[layer_sub].append(gradient.view(-1))

        for key in val_graddict:
            val_graddict[key] = torch.cat(val_graddict[key])
        self.model.zero_grad()
        memory_loss.detach_()
        return val_graddict
Example #9
            global_task_count += args.batch_tasks

            support_inputs, support_targets, support_semantics = get_inputs_and_outputs(
                args, train_batch['train'])
            query_inputs, query_targets, _ = get_inputs_and_outputs(
                args, train_batch['test'])

            support_embeddings, ca_weights, sa_weights = model(
                support_inputs,
                semantics=support_semantics,
                output_weights=True)

            query_embeddings = model(query_inputs)

            prototypes = get_prototypes(support_embeddings, support_targets,
                                        train_dataset.num_classes_per_task)

            train_loss = prototypical_loss(prototypes, query_embeddings,
                                           query_targets)
            train_acc = get_proto_accuracy(prototypes, query_embeddings,
                                           query_targets)
            del ca_weights, sa_weights

            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

            pbar.set_postfix(train_acc='{0:.4f}'.format(train_acc.item()))

            # validation
            if global_task_count % args.valid_every_tasks == 0:
Example #10
def train(args):

    datasets = []
    for root_idx in tqdm(range(args.ds_size)):

        folders = {
            "index": root_idx,
            "cache": "/nrs/funke/wolfs2/lisl/datasets/prototypical_network_cache_uncompressed_fg2.zarr",
            "raw": ("/nrs/funke/wolfs2/lisl/datasets/dsb_indexed.zarr",
                    f"train/raw/{root_idx}"),
            "gt_segmentation": ("/nrs/funke/wolfs2/lisl/datasets/dsb_indexed.zarr",
                                f"train/gt_segmentation/{root_idx}"),
            "embedding": (("/nrs/funke/wolfs2/lisl/experiments/semantic/c32/prediction/anchor.zarr",
                           f"train/prediction_interm/{root_idx}"),
                          ("/nrs/funke/wolfs2/lisl/experiments/semantic/c32/prediction/semantic.zarr",
                           f"train/prediction/{root_idx}")),
            "min_samples": 10,
            "bg_distance": 20
        }

        ds = fast_dataset_creator(folders,
                                  shots=args.num_shots,
                                  ways=args.num_ways,
                                  shuffle=True,
                                  test_shots=3,
                                  transform=None,
                                  meta_train=True)

        if len(ds):
            datasets.append(ds)
        else:
            print(
                f"dataset with id {root_idx} appears to be empty. Will be skipped"
            )

    os.makedirs(args.output_folder, exist_ok=True)

    loaders = []
    for ds in datasets:
        loaders.append(
            BatchMetaDataLoader(ds,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=args.num_workers))
        time.sleep(1)
        print("ds loaded")

    model = PrototypicalNetwork(544,
                                args.inst_embedding_size,
                                2,
                                hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    print("starting")

    accuracy_window = {}
    accuracy_window["inst_accuracy"] = []
    accuracy_window["sem_accuracy"] = []
    accuracy_window["combined"] = []

    # Training loop
    for epoch in range(500):

        dataloader = roundrobin_break_early(*loaders)

        with tqdm(dataloader, total=args.num_batches) as pbar:
            for batch_idx, batch in enumerate(pbar):
                model.zero_grad()

                train_inputs, train_instance_targets, train_semantic_targets = batch[
                    'train']
                train_inputs = train_inputs.to(device=args.device)
                train_instance_targets = train_instance_targets.to(
                    device=args.device)
                train_semantic_targets = train_semantic_targets.to(
                    device=args.device)

                train_semantic_embeddings, train_spatial_instance_embeddings = model(
                    train_inputs)

                test_inputs, test_instance_targets, test_semantic_targets = batch[
                    'test']
                test_inputs = test_inputs.to(device=args.device)
                test_instance_targets = test_instance_targets.to(
                    device=args.device)
                test_semantic_targets = test_semantic_targets.to(
                    device=args.device)
                test_semantic_embeddings, test_spatial_instance_embeddings = model(
                    test_inputs)

                # semantic loss
                c = train_semantic_embeddings.shape[-1]
                semantic_loss = F.cross_entropy(
                    train_semantic_embeddings.view(-1, c),
                    train_semantic_targets.view(-1))

                # semantic_prototypes = get_prototypes(train_semantic_embeddings, train_semantic_targets, 2)
                # semantic_loss = prototypical_loss(semantic_prototypes, test_semantic_embeddings, test_semantic_targets)

                # instance loss
                train_inst_emb = train_spatial_instance_embeddings
                test_inst_emb = test_spatial_instance_embeddings
                instance_prototypes = get_prototypes(train_inst_emb,
                                                     train_instance_targets,
                                                     args.num_ways)
                instance_loss = prototypical_loss(instance_prototypes,
                                                  test_inst_emb,
                                                  test_instance_targets)

                loss = semantic_loss + instance_loss

                loss.backward()
                optimizer.step()

                with torch.no_grad():
                    inst_accuracy = get_accuracy(instance_prototypes,
                                                 test_inst_emb,
                                                 test_instance_targets)
                    pred = test_semantic_embeddings.max(-1).indices
                    sem_accuracy = (pred == test_semantic_targets).sum().item() / (
                        pred.size(0) * pred.size(1))
                    sem_accuracy = float(sem_accuracy)

                    if len(accuracy_window["inst_accuracy"]) > 100:
                        accuracy_window["inst_accuracy"].pop(0)
                        accuracy_window["sem_accuracy"].pop(0)
                        accuracy_window["combined"].pop(0)

                    accuracy_window["inst_accuracy"].append(
                        float(inst_accuracy))
                    accuracy_window["sem_accuracy"].append(float(sem_accuracy))
                    accuracy_window["combined"].append(
                        float(inst_accuracy + sem_accuracy) / 2)

                    mean_inst_acc = np.mean(accuracy_window['inst_accuracy'])
                    mean_sem_acc = np.mean(accuracy_window['sem_accuracy'])
                    pbar.set_postfix(inst_accuracy=f"{mean_inst_acc:.4f}",
                                     sem_accuracy=f'{mean_sem_acc:.4f}',
                                     loss=f"{loss:.4}")

                if pbar.n > args.num_batches:
                    break

        # Save model
        score = np.mean(accuracy_window["combined"])
        if args.output_folder is not None:
            filename = os.path.join(
                args.output_folder,
                f'prototypical_networks_class_{epoch}_{args.num_shots}shot_{args.num_ways}way_{args.ds_size}_sem_spatial_{score:.4f}.pytorch'
            )
            print("saving", filename)
            torch.save(model.state_dict(), filename)
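
Loading one of the checkpoints saved above could look like this (hypothetical usage; the constructor arguments mirror the ones used in train):

# Hypothetical usage; `args` and `filename` as in train() above.
model = PrototypicalNetwork(544,
                            args.inst_embedding_size,
                            2,
                            hidden_size=args.hidden_size)
model.load_state_dict(torch.load(filename, map_location=args.device))
model.eval()
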