# Example #1
# 0
    def _encoder_model(self, input_shape, hyperparameters):
        """Build the image encoder: a frozen SqueezeNet backbone topped by a
        trainable dense embedding layer.

        Args:
            input_shape: spatial shape of the input images; only the first two
                entries (height, width) are used — the channel count is forced
                to 3 (RGB).
            hyperparameters: dict with an 'embedding_dim' entry giving the
                embedding size (a scalar or a shape tuple; collapsed with
                ``np.prod`` into the dense-layer width).

        Returns:
            A Keras ``Model`` mapping input images to embedding vectors; the
            SqueezeNet layers are frozen so only the dense head trains.
        """
        # BUG FIX: the original ignored the `input_shape` parameter and read
        # `self.input_shape` instead; use the explicit argument.
        squeezenet = SqueezeNet(
            input_shape=(input_shape[0], input_shape[1], 3),
            include_top=False,
        )
        x = Flatten()(squeezenet.output)
        # np.prod flattens a possibly multi-dimensional embedding_dim into a
        # single unit count for the Dense layer.
        embedding = Dense(np.prod(hyperparameters['embedding_dim']), activation='relu')(x)

        encoder = Model(squeezenet.input, embedding)
        # Freeze the backbone so the pretrained features are not updated.
        utils.freeze_layers(squeezenet)
        return encoder
# Example #2
# 0
    def fit(self, x_texts, y_train, validation_split=0, epochs=1):
        """Train the classifier head on top of the frozen language model.

        Args:
            x_texts: raw text inputs, vectorized via ``self.vectorize_text``.
            y_train: integer class labels; one-hot encoded internally.
            validation_split: fraction of data held out; enables early stopping
                when greater than zero.
            epochs: number of training epochs.

        Returns:
            Tuple of (loss, accuracy) recorded for the first epoch.
        """
        # The language model acts as a fixed feature extractor.
        utils.freeze_layers(self.language_model)
        labels = utils.one_hot_encode(y_train, self.num_classes)
        features = self.vectorize_text(x_texts)

        # Early stopping only makes sense when a validation set exists.
        callbacks = [EarlyStopping(patience=3)] if validation_split > 0. else None

        with self.graph.as_default():
            history = self.classifier.fit(
                features,
                labels,
                validation_split=validation_split,
                callbacks=callbacks,
                epochs=epochs,
                verbose=0,
            )
            metrics = history.history
            return metrics['loss'][0], metrics['acc'][0]
def main(args):
    """Meta-train a MAML-style classifier on Omniglot.

    Each step samples a set of trajectory classes, builds support/query
    batches, performs one meta-update, and periodically logs training and
    test accuracy to tensorboard and the experiment logger.
    """
    utils.set_seed(args.seed)

    # Experiment bookkeeping: results directory plus a tensorboard writer.
    my_experiment = experiment(args.name,
                               args,
                               "results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')

    # Using first 963 classes of the omniglot as the meta-training set
    args.classes = list(range(963))

    # Trajectory (inner-loop) classes: the second half of the 963 classes.
    args.traj_classes = list(range(int(963 / 2), 963))

    dataset = df.DatasetFactory.get_dataset(args.dataset,
                                            background=True,
                                            train=True,
                                            all=True)
    dataset_test = df.DatasetFactory.get_dataset(args.dataset,
                                                 background=True,
                                                 train=False,
                                                 all=True)

    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(dataset_test,
                                                batch_size=5,
                                                shuffle=True,
                                                num_workers=1)

    iterator_train = torch.utils.data.DataLoader(dataset,
                                                 batch_size=5,
                                                 shuffle=True,
                                                 num_workers=1)

    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                            dataset, dataset_test)

    config = mf.ModelFactory.get_model("na", args.dataset)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    maml = MetaLearingClassification(args, config).to(device)

    # Optionally freeze the representation-learning layers (controlled by
    # args.rln inside the helper).
    utils.freeze_layers(args.rln, maml)

    for step in range(args.steps):

        # Sample `args.tasks` distinct trajectory classes for this meta-step.
        t1 = np.random.choice(args.traj_classes, args.tasks, replace=False)

        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))

        d_rand_iterator = sampler.get_complete_iterator()

        # Build the support (inner-loop) and query (outer-loop) batches.
        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators,
            d_rand_iterator,
            steps=args.update_step,
            reset=not args.no_reset)
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(
            ), x_qry.cuda(), y_qry.cuda()

        # One meta-update; accs holds per-inner-step accuracies.
        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)

        # Evaluation during training for sanity checks
        if step % 40 == 39:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))
        if step % 300 == 299:
            utils.log_accuracy(maml, my_experiment, iterator_test, device,
                               writer, step)
            utils.log_accuracy(maml, my_experiment, iterator_train, device,
                               writer, step)
# Example #4
# 0
def main(args):
    """Meta-train on Omniglot and periodically run few-shot test evaluation.

    Unlike the plain training variant, every 600 steps this saves the learner
    and averages finetuning accuracy over 40 few-shot tasks drawn from a
    held-out sampler.
    """
    utils.set_seed(args.seed)

    # Experiment bookkeeping: results directory plus a tensorboard writer.
    my_experiment = experiment(args.name,
                               args,
                               "../results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')

    # Using first 963 classes of the omniglot as the meta-training set
    args.classes = list(range(963))

    # Here ALL 963 classes are used for trajectories (no half split).
    args.traj_classes = list(range(963))
    #
    dataset = df.DatasetFactory.get_dataset(args.dataset,
                                            background=True,
                                            train=True,
                                            all=True)
    # Evaluation set: the non-background alphabet (background=False).
    dataset_test = df.DatasetFactory.get_dataset(args.dataset,
                                                 background=False,
                                                 train=True,
                                                 all=True)

    # Note: the training sampler uses `dataset` for both roles.
    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                            dataset, dataset)

    # Test sampler over the 600 evaluation classes.
    sampler_test = ts.SamplerFactory.get_sampler(args.dataset,
                                                 list(range(600)),
                                                 dataset_test, dataset_test)

    config = mf.ModelFactory.get_model("na", "omniglot-fc")

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    maml = MetaLearingClassification(args, config).to(device)

    # Optionally freeze the representation-learning layers (controlled by
    # args.rln inside the helper).
    utils.freeze_layers(args.rln, maml)

    for step in range(args.steps):

        # Sample `args.tasks` distinct trajectory classes for this meta-step.
        t1 = np.random.choice(args.traj_classes, args.tasks, replace=False)

        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))

        d_rand_iterator = sampler.get_complete_iterator()

        # Few-shot support/query batches for the inner/outer loops.
        x_spt, y_spt, x_qry, y_qry = maml.sample_few_shot_training_data(
            d_traj_iterators,
            d_rand_iterator,
            steps=args.update_step,
            reset=not args.no_reset)
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(
            ), x_qry.cuda(), y_qry.cuda()

        # One meta-update; accs holds per-inner-step accuracies.
        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)
        #
        # Evaluation during training for sanity checks
        if step % 20 == 0:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))
            logger.info("Loss = %s", str(loss[-1].item()))
        if step % 600 == 599:
            # Checkpoint the learner, then evaluate finetuning on 40 sampled
            # few-shot tasks from the held-out classes.
            torch.save(maml.net, my_experiment.path + "learner.model")
            accs_avg = None
            for temp_temp in range(0, 40):
                t1_test = np.random.choice(list(range(600)),
                                           args.tasks,
                                           replace=False)

                d_traj_test_iterators = []
                for t in t1_test:
                    d_traj_test_iterators.append(sampler_test.sample_task([t]))

                x_spt, y_spt, x_qry, y_qry = maml.sample_few_shot_training_data(
                    d_traj_test_iterators,
                    None,
                    steps=args.update_step,
                    reset=not args.no_reset)
                if torch.cuda.is_available():
                    x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(
                    ), x_qry.cuda(), y_qry.cuda()

                accs, loss = maml.finetune(x_spt, y_spt, x_qry, y_qry)
                # NOTE(review): `+=` assumes accs is a numpy array (element-wise
                # add); if it were a plain list this would concatenate — verify
                # the return type of maml.finetune.
                if accs_avg is None:
                    accs_avg = accs
                else:
                    accs_avg += accs
            logger.info("Loss = %s", str(loss[-1].item()))
            writer.add_scalar('/metatest/train/accuracy', accs_avg[-1] / 40,
                              step)
            logger.info('TEST: step: %d \t testing acc %s', step,
                        str(accs_avg / 40))
def main(args):
    """Meta-train a classifier on Omniglot with optional GPU prefetching and
    warm-starting from a saved checkpoint.

    Periodically saves the learner network and logs train/test accuracy.
    """
    utils.set_seed(args.seed)

    # Experiment bookkeeping: results directory plus a tensorboard writer.
    my_experiment = experiment(args.name,
                               args,
                               "../results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger("experiment")

    # Using first 963 classes of the omniglot as the meta-training set
    args.classes = list(range(963))

    if torch.cuda.is_available():
        device = torch.device("cuda")
        use_cuda = True
    else:
        device = torch.device("cpu")
        use_cuda = False

    dataset = df.DatasetFactory.get_dataset(
        args.dataset,
        background=True,
        train=True,
        all=True,
        prefetch_gpu=args.prefetch_gpu,
        device=device,
    )
    # Test evaluation reuses the training dataset (the true test split is
    # commented out below).
    dataset_test = dataset
    # dataset_test = df.DatasetFactory.get_dataset(
    #     args.dataset, background=True, train=False, all=True
    # )

    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(dataset_test,
                                                batch_size=5,
                                                shuffle=True,
                                                num_workers=1)

    iterator_train = torch.utils.data.DataLoader(dataset,
                                                 batch_size=5,
                                                 shuffle=True,
                                                 num_workers=1)

    sampler = ts.SamplerFactory.get_sampler(
        args.dataset,
        args.classes,
        dataset,
        dataset_test,
        prefetch_gpu=args.prefetch_gpu,
        use_cuda=use_cuda,
    )

    config = mf.ModelFactory.get_model(args.treatment, args.dataset)

    maml = MetaLearingClassification(args, config, args.treatment).to(device)

    if args.checkpoint:
        # Warm-start: copy parameter tensors from the saved model into maml.
        checkpoint = torch.load(args.saved_model, map_location="cpu")

        # NOTE(review): this indexes `parameters()` and takes len(checkpoint),
        # which plain nn.Module does not support (parameters() is an
        # iterator) — presumably the project's net overrides parameters() to
        # return a list; verify against the model class.
        for idx in range(len(checkpoint)):
            maml.net.parameters()[idx].data = checkpoint.parameters()[idx].data

    maml = maml.to(device)

    # Optionally freeze the representation-learning layers (controlled by
    # args.rln inside the helper).
    utils.freeze_layers(args.rln, maml)

    for step in range(args.steps):

        # Sample `args.tasks` distinct classes for this meta-step (here drawn
        # from all classes, not a trajectory subset).
        t1 = np.random.choice(args.classes, args.tasks, replace=False)

        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))

        d_rand_iterator = sampler.get_complete_iterator()

        # Build the support (inner-loop) and query (outer-loop) batches.
        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators,
            d_rand_iterator,
            steps=args.update_step,
            reset=not args.no_reset,
        )
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = (
                x_spt.cuda(),
                y_spt.cuda(),
                x_qry.cuda(),
                y_qry.cuda(),
            )

        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)  # , args.tasks)

        # Evaluation during training for sanity checks
        if step % 40 == 0:
            # writer.add_scalar('/metatrain/train/accuracy', accs, step)
            logger.info("step: %d \t training acc %s", step, str(accs))
        if step % 100 == 0 or step == 19999:
            # Periodic checkpoint of the learner network.
            torch.save(maml.net, args.model_name)
        if step % 2000 == 0 and step != 0:
            utils.log_accuracy(maml, my_experiment, iterator_test, device,
                               writer, step)
            utils.log_accuracy(maml, my_experiment, iterator_train, device,
                               writer, step)
# Example #6
# 0
def main(args):
    """Meta-train on Omniglot while tracking best-so-far accuracy/loss.

    Variant that passes running `old_accs`/`old_meta_losses` state into the
    maml forward call and saves the classifier every step.
    """
    # Placeholder variables
    # old_meta_losses starts at [huge, 0] so any real loss improves on it.
    old_accs = [0]
    old_meta_losses = [2.**30, 0]

    utils.set_seed(args.seed)

    # Experiment bookkeeping: results directory plus a tensorboard writer.
    my_experiment = experiment(args.name,
                               args,
                               "./results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    logger = logging.getLogger('experiment')

    # Using first 963 classes of the omniglot as the meta-training set
    args.classes = list(range(963))
    # Trajectory (inner-loop) classes: the second half of the 963 classes.
    args.traj_classes = list(range(int(963 / 2), 963))

    dataset = df.DatasetFactory.get_dataset(args.dataset,
                                            background=True,
                                            train=True,
                                            all=True)
    dataset_test = df.DatasetFactory.get_dataset(args.dataset,
                                                 background=True,
                                                 train=False,
                                                 all=True)
    # print("ONE ITEM", len(dataset.__getitem__(0)),dataset.__getitem__(0)[0].shape,dataset.__getitem__(0)[1])
    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(dataset_test,
                                                batch_size=5,
                                                shuffle=True,
                                                num_workers=1)
    iterator_train = torch.utils.data.DataLoader(dataset,
                                                 batch_size=5,
                                                 shuffle=True,
                                                 num_workers=1)
    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                            dataset, dataset_test)
    # print("NUM CLASSES",args.classes)
    config = mf.ModelFactory.get_model("na", args.dataset)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # maml = MetaLearingClassification(args, config).to(device)
    maml = MetaLearingClassification(args, config).to(device)
    utils.freeze_layers(args.rln, maml)  # freeze layers

    for step in range(args.steps):  #epoch
        print("STEP: ", step)
        # Sample `args.tasks` distinct trajectory classes for this meta-step.
        t1 = np.random.choice(args.traj_classes, args.tasks,
                              replace=False)  #sample sine waves
        # print("TRAJ CLASSES<",args.tasks)
        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))
        # print("ANNOYINGNESS",d_traj_iterators)
        d_rand_iterator = sampler.get_complete_iterator()

        # Sample trajectory and random batch (support and query)
        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators,
            d_rand_iterator,
            steps=args.update_step,
            reset=not args.no_reset)
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(
            ), x_qry.cuda(), y_qry.cuda()

        # One training loop
        # This variant also passes the running best accs/losses and config
        # into the forward call.
        accs, loss = maml(x_spt, y_spt, x_qry, y_qry, step, old_accs,
                          old_meta_losses, args, config)

        # if loss[-2] >= old_meta_losses[-2]: #if training improves it,
        #     maml.set_self(other.get_self_state_dict())
        #     old_meta_losses = loss

        # else: #if not improved
        #     other.set_self(maml.get_self_state_dict())

        # Evaluation during training for sanity checks
        if step % 40 == 39:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))
        if step % 300 == 299:
            utils.log_accuracy(maml, my_experiment, iterator_test, device,
                               writer, step)
            utils.log_accuracy(maml, my_experiment, iterator_train, device,
                               writer, step)

        # NOTE(review): the model is saved on EVERY step (inside the loop) —
        # confirm this is intended rather than a once-per-run save.
        torch.save(maml.net, my_experiment.path + "omniglot_classifier.model")
def train(p, train_loader, model, optimizer, epoch, amp):
    """Run one training epoch of the contrastive + saliency objective.

    Args:
        p: config dict; reads p['freeze_layers'] and p['gpu'].
        train_loader: yields dicts with 'query'/'key' views, each containing
            'image' and 'sal' tensors.
        model: callable returning (logits, labels, saliency_loss) for a pair
            of augmented views.
        optimizer: optimizer over the model's trainable parameters.
        epoch: current epoch index (used for progress display only).
        amp: mixed-precision handle with scale_loss(), or None for fp32.
    """
    loss_meter = AverageMeter('Loss', ':.4e')
    contrastive_meter = AverageMeter('Contrastive', ':.4e')
    saliency_meter = AverageMeter('CE', ':.4e')
    acc1_meter = AverageMeter('Acc@1', ':6.2f')
    acc5_meter = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [loss_meter, contrastive_meter, saliency_meter, acc1_meter, acc5_meter],
        prefix="Epoch: [{}]".format(epoch))
    model.train()

    if p['freeze_layers']:
        model = freeze_layers(model)

    for batch_idx, sample in enumerate(train_loader):
        # Move both augmented views and their saliency masks to the GPU.
        query_img = sample['query']['image'].cuda(p['gpu'], non_blocking=True)
        key_img = sample['key']['image'].cuda(p['gpu'], non_blocking=True)
        query_sal = sample['query']['sal'].cuda(p['gpu'], non_blocking=True)
        key_sal = sample['key']['sal'].cuda(p['gpu'], non_blocking=True)

        logits, labels, saliency_loss = model(im_q=query_img,
                                              im_k=key_img,
                                              sal_q=query_sal,
                                              sal_k=key_sal)

        # E-Net style class weighting: rarer classes get larger weights.
        uniq, freq = torch.unique(labels, return_counts=True)
        class_prob = torch.zeros(logits.shape[1],
                                 dtype=torch.float32).cuda(p['gpu'],
                                                           non_blocking=True)
        class_prob[uniq] = freq.float() / labels.numel()
        class_weight = 1 / torch.log(1.02 + class_prob)
        contrastive_loss = cross_entropy(logits,
                                         labels,
                                         weight=class_weight,
                                         reduction='mean')

        # Combine the objectives and record running averages.
        loss = contrastive_loss + saliency_loss
        contrastive_meter.update(contrastive_loss.item())
        saliency_meter.update(saliency_loss.item())
        loss_meter.update(loss.item())

        acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
        acc1_meter.update(acc1[0], query_img.size(0))
        acc5_meter.update(acc5[0], query_img.size(0))

        # Backward pass and parameter update (mixed precision if available).
        optimizer.zero_grad()
        if amp is not None:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        # Periodic progress display.
        if batch_idx % 25 == 0:
            progress.display(batch_idx)