Example #1
    def setup(self,
              datacrawler,
              mode='train',
              batch_size=32,
              instance=6,
              workers=8):
        """ Setup the data generator.

    Args:
      workers (int): Number of workers to use during data retrieval/loading
      datacrawler (VeRiDataCrawler): A DataCrawler object that has crawled the data directory
      mode (str): One of 'train', 'test', 'query'. 
    """
        if datacrawler is None:
            raise ValueError("Must pass DataCrawler instance. Passed `None`")
        self.workers = workers * self.gpus

        # If training, get images whose labels are 0-99
        # If testing, get images whose labels are 100-199
        if mode == "train":
            self.__dataset = TDataSet(
                datacrawler.metadata["train"]["crawl"] +
                datacrawler.metadata["test"]["crawl"], self.transformer,
                range(0, 100))
        elif mode == "train-gzsl":
            self.__dataset = TDataSet(datacrawler.metadata["train"]["crawl"],
                                      self.transformer, range(0, 100))
        elif mode == "zsl" or mode == "test":
            self.__dataset = TDataSet(datacrawler.metadata["query"]["crawl"],
                                      self.transformer, range(100, 200))
        elif mode == "gzsl":  # For the generalized zero shot learning mode
            self.__dataset = TDataSet(
                datacrawler.metadata["test"]["crawl"] +
                datacrawler.metadata["query"]["crawl"], self.transformer,
                range(0, 200))

        else:
            raise NotImplementedError()

        if mode == "train" or mode == "train-gzsl":
            self.dataloader = TorchDataLoader(self.__dataset, batch_size=batch_size*self.gpus, \
                                              shuffle=True, \
                                              num_workers=self.workers, drop_last=True, collate_fn=self.collate_simple)
            self.num_entities = 100
        elif mode == "zsl" or mode == "test":
            self.dataloader = TorchDataLoader(self.__dataset, batch_size=batch_size*self.gpus, \
                                              shuffle = False,
                                              num_workers=self.workers, drop_last=True, collate_fn=self.collate_simple)
            self.num_entities = 100
        elif mode == "gzsl":
            self.dataloader = TorchDataLoader(self.__dataset, batch_size=batch_size*self.gpus, \
                                              shuffle = False,
                                              num_workers=self.workers, drop_last=True, collate_fn=self.collate_simple)
            self.num_entities = 200

        else:
            raise NotImplementedError()
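TDataSet is project-specific and not shown in these snippets. As a rough illustration of what the `range(0, 100)` / `range(100, 200)` arguments imply, the hypothetical dataset below keeps only crawled entries whose label falls inside the given range; the `(path, label)` entry layout, the class name, and the PIL loading are assumptions, not the project's actual implementation.

from PIL import Image
from torch.utils.data import Dataset


class LabelRangeDataset(Dataset):
    """Hypothetical sketch: keep only crawl entries whose label is in `label_range`."""

    def __init__(self, crawl, transform, label_range):
        # Assumes each crawl entry is a (path, label) tuple
        self.samples = [(path, label) for path, label in crawl
                        if label in label_range]
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(path).convert("RGB")
        return self.transform(img), label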
Example #2
    def setup(self,
              datacrawler,
              mode="train",
              batch_size=32,
              workers=8,
              preload_classes=[]):
        """ Setup the data generator.

        Args:
            workers (int): Number of workers to use during data retrieval/loading
            datacrawler (VeRiDataCrawler): A DataCrawler object that has crawled the data directory
            mode (str): One of 'train', 'test', 'query'. 
        """
        if datacrawler is None:
            raise ValueError("Must pass DataCrawler instance. Passed `None`")
        self.workers = workers * self.gpus

        train_mode = mode == "train"
        target_convert = None
        if datacrawler in ["MNIST", "CIFAR10", "CIFAR100"]:
            __dataset = getattr(torchvision.datasets, datacrawler)
            self.__dataset = __dataset(root="./" + datacrawler,
                                       train=train_mode,
                                       download=True,
                                       transform=self.transformer)
            # Add extra channel to MNIST
            if datacrawler == "MNIST":
                self.__dataset.data = self.__dataset.data.unsqueeze(3)
            # The CIFAR10/CIFAR100 torchvision datasets store targets as a plain
            # Python list rather than a tensor, so convert for the masking below
            if datacrawler in ["CIFAR10", "CIFAR100"]:
                if isinstance(self.__dataset.targets, list):
                    target_convert = "list"
                    self.__dataset.targets = torch.Tensor(
                        self.__dataset.targets).int()

            # If preload_classes is given, keep only samples whose label is in
            # that list (logical OR over the per-class masks)
            if len(preload_classes) > 0:
                valid_idxs = self.__dataset.targets == preload_classes[0]
                for _remaining in preload_classes[1:]:
                    valid_idxs |= self.__dataset.targets == _remaining
                self.__dataset.targets = self.__dataset.targets[valid_idxs]
                self.__dataset.data = self.__dataset.data[valid_idxs]
            if target_convert == "list":
                self.__dataset.targets = self.__dataset.targets.tolist()
        else:
            raise NotImplementedError()

        self.dataloader = TorchDataLoader(self.__dataset,
                                          batch_size=batch_size * self.gpus,
                                          shuffle=True,
                                          num_workers=self.workers)
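As a self-contained illustration of the `preload_classes` filtering above, the snippet below applies the same masking idea directly to torchvision's CIFAR10 (class indices 0, 1, 2 are airplane, automobile, and bird). It mirrors the logic in `setup()` but is not part of the original code.

import torch
import torchvision
import torchvision.transforms as T

# Illustration only: restrict CIFAR10 to three classes, as setup() does via preload_classes
dataset = torchvision.datasets.CIFAR10(root="./CIFAR10", train=True,
                                       download=True, transform=T.ToTensor())
targets = torch.Tensor(dataset.targets).int()   # list -> tensor, as in setup()
keep = torch.zeros_like(targets, dtype=torch.bool)
for cls in [0, 1, 2]:                           # airplane, automobile, bird
    keep |= targets == cls
dataset.data = dataset.data[keep.numpy()]       # numpy image array, filtered
dataset.targets = targets[keep].tolist()        # back to a list for the torchvision API
print(len(dataset))                             # 15000 = 3 classes x 5000 images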
Example #3
    def setup(self,
              datacrawler,
              mode='train',
              batch_size=32,
              instance=8,
              workers=8):
        """ Setup the data generator.

    Args:
      workers (int): Number of workers to use during data retrieval/loading
      datacrawler (VeRiDataCrawler): A DataCrawler object that has crawled the data directory
      mode (str): One of 'train', 'test', 'query'. 
    """
        if datacrawler is None:
            raise ValueError("Must pass DataCrawler instance. Passed `None`")
        self.workers = workers * self.gpus

        if mode == "train":
            self.__dataset = TDataSet(datacrawler.metadata[mode]["crawl"],
                                      self.transformer)
        elif mode == "test":
            # For testing, we combine images in the query and testing set to generate batches
            self.__dataset = TDataSet(
                datacrawler.metadata[mode]["crawl"] +
                datacrawler.metadata["query"]["crawl"], self.transformer)
        else:
            raise NotImplementedError()

        if mode == "train":
            self.dataloader = TorchDataLoader(self.__dataset, batch_size=batch_size*self.gpus, \
                                              sampler = TSampler(datacrawler.metadata[mode]["crawl"], batch_size=batch_size*self.gpus, instance=instance*self.gpus), \
                                              num_workers=self.workers, collate_fn=self.collate_simple)
            self.num_entities = datacrawler.metadata[mode]["pids"]
        elif mode == "test":
            self.dataloader = TorchDataLoader(self.__dataset, batch_size=batch_size*self.gpus, \
                                              shuffle = False,
                                              num_workers=self.workers, collate_fn=self.collate_with_camera)
            self.num_entities = len(datacrawler.metadata["query"]["crawl"])
        else:
            raise NotImplementedError()
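TSampler is project-specific and not shown here. The `instance` argument suggests an identity-balanced (PK) sampler of the kind commonly used for re-identification training: each batch holds `batch_size / instance` identities with `instance` images each. A minimal sketch of such a sampler follows; the `(path, pid, camera)` crawl-entry layout and every name in it are assumptions, not the project's actual TSampler.

import random
from collections import defaultdict
from torch.utils.data import Sampler


class IdentityBatchSampler(Sampler):
    """Hypothetical PK sampler: yields `instance` image indices per identity per batch."""

    def __init__(self, crawl, batch_size, instance):
        self.instance = instance
        self.ids_per_batch = batch_size // instance
        # Group dataset indices by identity; assumes each crawl entry is (path, pid, camera)
        self.index_by_pid = defaultdict(list)
        for idx, (_, pid, _) in enumerate(crawl):
            self.index_by_pid[pid].append(idx)
        self.pids = list(self.index_by_pid.keys())

    def __iter__(self):
        pids = self.pids[:]
        random.shuffle(pids)
        for start in range(0, len(pids) - self.ids_per_batch + 1, self.ids_per_batch):
            for pid in pids[start:start + self.ids_per_batch]:
                pool = self.index_by_pid[pid]
                if len(pool) >= self.instance:
                    yield from random.sample(pool, self.instance)
                else:
                    # Sample with replacement when an identity has too few images
                    yield from random.choices(pool, k=self.instance)

    def __len__(self):
        return (len(self.pids) // self.ids_per_batch) * self.ids_per_batch * self.instance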
Example #4
def train_models() -> None:
    """
    # Start with smooth unit hypersphere -> We cover uniformly the space and randomly assign the features representation to the input image.
    # At every iteration we solve the credit assignment problem for each mini batch
    # We compute the loss between each input image and all the targets in the minibatch
    # Then we treat this as credit assignment problem -> we reassign the targets to the images in the mini batch s.t. distance between current representation and target
    # is minimized
    # we shuffle the dataset every epoch (Super Important)  -> Required to ensure each minibatch gets different image / target as learning goes.
    # As network learns the targets close in space should get assign to similar images.

    :return: Nothing
    """

    parser = initialise_arg_parser()
    args = parser.parse_args()

    checkpoint = None
    if args.resume:
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = get_checkpoint(args.resume)
        if checkpoint is not None:
            args = restore_args(args, checkpoint)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    im_grayscale = not args.im_color
    im_gradients = not args.no_gradients

    encoder, decoder = initialise_resnet(use_grayscale=im_grayscale,
                                         use_im_gradients=im_gradients,
                                         im_shape=im_shape,
                                         num_classes=num_classes)

    train_dir = os.path.join(args.data_dir, 'train')
    val_dir = os.path.join(args.data_dir, 'val')

    train_transforms, val_transforms = initialise_transforms(args)

    trainset = NatImageFolder(
        root=train_dir,
        z_dims=encoder.get_output_shape(input_shape=im_shape),
        transform=train_transforms)

    dl_trainset = DataLoader(trainset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.workers,
                             pin_memory=use_cuda)

    # End-to-end eval for the encoder / decoder pair.
    val_set = ImageFolder(val_dir, transform=val_transforms)

    dl_val_set = TorchDataLoader(val_set,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 pin_memory=use_cuda)

    writer = SummaryWriter(log_dir=args.log_dir)

    # Freeze the first layer (computes image gradients from grayscale).
    encoder_parameters = encoder.features.parameters()

    if use_cuda:
        encoder.cuda()
        decoder.cuda()

    encoder_loss_fn = nn.MSELoss()
    decoder_loss_fn = nn.CrossEntropyLoss()

    encoder_optim = torch.optim.Adam(encoder_parameters,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    decoder_optim = torch.optim.Adam(decoder.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)

    # LR schedule in the paper: constant until t_0, then l_0 / (1 + gamma * max(t - t_0, 0)).
    # Target permutation every 3 epochs;
    # freeze the encoder and train the classifier every 20 epochs.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(encoder_optim,
                                                     milestones=[10, 20],
                                                     gamma=0.5)
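    # The paper's schedule, l_0 / (1 + gamma * max(t - t_0, 0)), could instead be
    # expressed with a LambdaLR; the gamma and t_0 values below are illustrative only:
    # scheduler = torch.optim.lr_scheduler.LambdaLR(
    #     encoder_optim, lr_lambda=lambda t: 1.0 / (1.0 + 0.5 * max(t - 10, 0)))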

    if use_cuda:
        decoder_loss_fn = decoder_loss_fn.cuda()
        encoder_loss_fn = encoder_loss_fn.cuda()

    best_acc = 0.0

    if checkpoint is not None:
        best_acc = checkpoint['best_acc']
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])

        encoder_optim.load_state_dict(checkpoint['encoder_optim'])
        decoder_optim.load_state_dict(checkpoint['decoder_optim'])
        print(
            "=> Successfully restored all model parameters. Restarting from epoch: {}"
            .format(args.current_epoch))

    for epoch in trange(args.current_epoch,
                        args.epochs,
                        desc='epochs: ',
                        leave=False):
        scheduler.step(epoch)

        update_targets = bool(((epoch + 1) % args.ut) == 0)
        train_decoder = bool(((epoch + 1) % args.td) == 0)

        # set encoder to training mode. (for batch norm layers)
        encoder.train(True)
        decoder.train(train_decoder)

        # Stream Training dataset with NAT
        for batch_idx, (idx, x, y, nat) in enumerate(tqdm(dl_trainset, 0), 1):
            e_targets = nat.numpy()
            if use_cuda:
                x, y, nat = x.cuda(), y.cuda(), nat.cuda()

            x = Variable(x)
            encoder_optim.zero_grad()
            outputs = encoder(x)

            # In epochs where update_targets is set, greedily re-assign the targets for this batch.
            if update_targets:
                e_out = outputs.cpu().data.numpy()
                new_targets = calc_optimal_target_permutation(e_out, e_targets)
                # update.
                trainset.update_targets(idx, new_targets)
                nat = torch.FloatTensor(new_targets)
                if use_cuda:
                    nat = nat.cuda()

            # train encoder
            nat = Variable(nat)
            encoder_loss = encoder_loss_fn(outputs, nat)
            encoder_loss.backward(retain_graph=True)
            encoder_optim.step()

            if batch_idx % 100 == 0:
                writer.add_scalar('encoder_loss', encoder_loss.data[0],
                                  int((epoch + 1) * (batch_idx / 100)))

            if train_decoder:
                y = Variable(y)
                decoder_optim.zero_grad()
                y_pred = decoder(outputs)
                decoder_loss = decoder_loss_fn(y_pred, y)
                decoder_loss.backward()
                decoder_optim.step()

                if batch_idx % 100 == 0:
                    idx_step = int(((epoch + 1) / args.td) * (batch_idx / 100))
                    writer.add_scalar('decoder_loss', decoder_loss.data[0],
                                      idx_step)

        # Write weight + gradient histograms for each epoch
        for name, param in encoder.named_parameters():
            name = 'encoder/' + name.replace('.', '/')
            writer.add_histogram(name,
                                 param.clone().cpu().data.numpy(), (epoch + 1))
            if param.grad is not None:
                writer.add_histogram(name + '/grad',
                                     param.grad.clone().cpu().data.numpy(),
                                     (epoch + 1))

        # if decoder has been trained, eval classifier
        if train_decoder:

            # Write decoder weight + gradient histograms for epochs where the decoder was trained.
            for name, param in decoder.named_parameters():
                name = 'decoder/' + name.replace('.', '/')
                writer.add_histogram(name,
                                     param.clone().cpu().data.numpy(),
                                     (epoch + 1))
                writer.add_histogram(name + '/grad',
                                     param.grad.clone().cpu().data.numpy(),
                                     (epoch + 1))

            # set models to eval mode and validate on test set.
            encoder.eval()
            decoder.eval()

            # Size the buffers to the dataset length so a smaller final batch
            # cannot leave uninitialised entries in the accuracy computation
            all_preds = np.empty(shape=(len(val_set),))
            all_truth = np.empty(shape=(len(val_set),))
            test_loss = 0.0

            for idx, (x, y) in tqdm(enumerate(dl_val_set, 0)):
                if use_cuda:
                    x, y = x.cuda(), y.cuda()
                x, y = Variable(x), Variable(y)

                y_pred = decoder(encoder(x))
                loss = decoder_loss_fn(y_pred, y)
                test_loss += loss.data[0]

                init_pos = idx * args.batch_size
                all_truth[init_pos:init_pos + len(y)] = y.cpu().data.numpy()
                y_hat = np.argmax(y_pred.cpu().data.numpy(), axis=1)
                all_preds[init_pos:init_pos + len(y)] = y_hat

            # compute test stats & add to tensorboard
            all_preds = all_preds.astype(np.int32)
            all_truth = all_truth.astype(np.int32)

            # acc score
            acc_score = accuracy_score(all_truth, all_preds)
            writer.add_scalar('accuracy_score', acc_score, (epoch + 1))

            if acc_score > best_acc:
                print('Saving best encoder / decoder pair...')

                state = {
                    'args': {
                        'arch': args.arch,
                        'im_color': args.im_color,
                        'no_gradients': args.no_gradients
                    },
                    'encoder_state_dict': encoder.state_dict(),
                    'decoder_state_dict': decoder.state_dict(),
                    'encoder_optim': encoder_optim.state_dict(),
                    'decoder_optim': decoder_optim.state_dict(),
                    'best_acc': acc_score,
                    'epoch': epoch + 1,
                }
                if not os.path.isdir(args.checkpoint_dir):
                    os.mkdir(args.checkpoint_dir)
                torch.save(
                    state,
                    os.path.join(args.checkpoint_dir,
                                 f'chkpt_full_{epoch}.pkl'))
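The helper `calc_optimal_target_permutation` used above is not shown. Going only by the docstring (reassign targets within a mini-batch so the total distance to the current representations is minimised), a minimal sketch using SciPy's Hungarian solver could look like this; the squared-Euclidean cost and the implementation details are assumptions, and the original helper may differ.

import numpy as np
from scipy.optimize import linear_sum_assignment


def calc_optimal_target_permutation(features: np.ndarray,
                                    targets: np.ndarray) -> np.ndarray:
    """Hedged sketch: permute `targets` to minimise total squared distance to `features`."""
    # Pairwise squared Euclidean distances between every feature row and every target row
    cost = ((features[:, None, :] - targets[None, :, :]) ** 2).sum(axis=2)
    # Hungarian algorithm: assign exactly one target to each feature at minimum total cost
    _, col_ind = linear_sum_assignment(cost)
    return targets[col_ind]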
Example #5
    config = CINConfig()
    dataset_val = OOIDataset("val")
    val_set = Dataset(dataset_val, config)
    def my_collate_fn(batch):
        # Drop samples for which __getitem__ returned None; if nothing in the
        # batch is valid, return a dummy tensor so iteration can continue
        batch = list(filter(lambda x: x is not None, batch))
        if len(batch) == 0:
            print("No valid data!!!")
            batch = [[torch.from_numpy(np.zeros([1, 1]))]]
        return default_collate(batch)

    val_generator = TorchDataLoader(val_set,
                                    collate_fn=my_collate_fn,
                                    batch_size=1,
                                    shuffle=True,
                                    num_workers=1)
    step = 0
    for inputs in val_generator:
        if len(inputs) != 17:
            print("length of inputs", len(inputs))
            print("inputs", inputs)
            continue
        print(str(step) + "/9000")
        step += 1

    '''
    result=load_image_gt(ooiDataset,OISMPSConfig(),9,True,True)
    thing_masks=result[4]
    stuff_masks=result[7]
    semantic_label=result[8]
    influence_mask=result[11]