def setup(self, datacrawler, mode='train', batch_size=32, instance=6, workers=8):
    """ Set up the dataset and dataloader for zero-shot / generalized-zero-shot runs.

    Labels 0-99 are the "seen" entities (training); labels 100-199 are the
    "unseen" entities (query). `gzsl` mixes both.

    Args:
        datacrawler (VeRiDataCrawler): A DataCrawler object that has crawled the data directory
        mode (str): One of 'train', 'train-gzsl', 'zsl', 'test', 'gzsl'.
        batch_size (int): Per-GPU batch size; multiplied by self.gpus.
        instance (int): Unused in this variant; kept for signature parity.
        workers (int): Number of workers per GPU used during data retrieval/loading
    """
    if datacrawler is None:
        raise ValueError("Must pass DataCrawler instance. Passed `None`")
    self.workers = workers * self.gpus

    meta = datacrawler.metadata
    # Per-mode configuration: (crawl list, label range, shuffle flag, entity count)
    if mode == "train":
        crawl = meta["train"]["crawl"] + meta["test"]["crawl"]
        labels, shuffle, entities = range(0, 100), True, 100
    elif mode == "train-gzsl":
        crawl = meta["train"]["crawl"]
        labels, shuffle, entities = range(0, 100), True, 100
    elif mode in ("zsl", "test"):
        crawl = meta["query"]["crawl"]
        labels, shuffle, entities = range(100, 200), False, 100
    elif mode == "gzsl":
        # Generalized zero-shot: seen and unseen entities together.
        crawl = meta["test"]["crawl"] + meta["query"]["crawl"]
        labels, shuffle, entities = range(0, 200), False, 200
    else:
        raise NotImplementedError()

    self.__dataset = TDataSet(crawl, self.transformer, labels)
    # NOTE(review): drop_last=True also applies to eval modes, so trailing
    # samples are discarded during evaluation — confirm this is intended.
    self.dataloader = TorchDataLoader(self.__dataset,
                                      batch_size=batch_size * self.gpus,
                                      shuffle=shuffle,
                                      num_workers=self.workers,
                                      drop_last=True,
                                      collate_fn=self.collate_simple)
    self.num_entities = entities
def setup(self, datacrawler, mode="train", batch_size=32, workers=8, preload_classes=None):
    """ Set up a torchvision dataset and its DataLoader.

    Args:
        datacrawler (str): Name of a supported torchvision dataset:
            "MNIST", "CIFAR10", or "CIFAR100".
        mode (str): "train" selects the training split; anything else the test split.
        batch_size (int): Per-GPU batch size; multiplied by self.gpus.
        workers (int): Number of workers per GPU used during data retrieval/loading
        preload_classes (list, optional): If given and non-empty, restrict the
            dataset to samples whose target is in this list. Defaults to no filtering.

    Raises:
        ValueError: If datacrawler is None.
        NotImplementedError: If datacrawler is not a supported dataset name.
    """
    if datacrawler is None:
        raise ValueError("Must pass DataCrawler instance. Passed `None`")
    if preload_classes is None:
        # Avoid the shared-mutable-default-argument pitfall; [] was the old default.
        preload_classes = []
    self.workers = workers * self.gpus
    train_mode = mode == "train"

    if datacrawler not in ("MNIST", "CIFAR10", "CIFAR100"):
        raise NotImplementedError()

    dataset_cls = getattr(torchvision.datasets, datacrawler)
    self.__dataset = dataset_cls(root="./" + datacrawler,
                                 train=train_mode,
                                 download=True,
                                 transform=self.transformer)

    # MNIST images are (N, H, W); add a trailing channel axis.
    if datacrawler == "MNIST":
        self.__dataset.data = self.__dataset.data.unsqueeze(3)

    # CIFAR10/CIFAR100 torchvision downloaders store targets as a plain list;
    # convert to a tensor so the class filtering below can use tensor masks,
    # and remember to convert back afterwards.
    target_convert = None
    if datacrawler in ("CIFAR10", "CIFAR100") and isinstance(self.__dataset.targets, list):
        target_convert = "list"
        self.__dataset.targets = torch.Tensor(self.__dataset.targets).int()

    if len(preload_classes) > 0:
        # Accumulate a boolean mask selecting every requested class.
        valid_idxs = self.__dataset.targets == preload_classes[0]
        for _remaining in preload_classes[1:]:
            valid_idxs |= self.__dataset.targets == _remaining
        self.__dataset.targets = self.__dataset.targets[valid_idxs]
        self.__dataset.data = self.__dataset.data[valid_idxs]

    if target_convert == "list":
        self.__dataset.targets = self.__dataset.targets.tolist()

    self.dataloader = TorchDataLoader(self.__dataset,
                                      batch_size=batch_size * self.gpus,
                                      shuffle=True,
                                      num_workers=self.workers)
def setup(self, datacrawler, mode='train', batch_size=32, instance=8, workers=8):
    """ Set up the dataset and dataloader for re-identification training/testing.

    Args:
        datacrawler (VeRiDataCrawler): A DataCrawler object that has crawled the data directory
        mode (str): One of 'train', 'test'.
        batch_size (int): Per-GPU batch size; multiplied by self.gpus.
        instance (int): Per-GPU instances per identity for the training sampler.
        workers (int): Number of workers per GPU used during data retrieval/loading
    """
    if datacrawler is None:
        raise ValueError("Must pass DataCrawler instance. Passed `None`")
    self.workers = workers * self.gpus

    effective_batch = batch_size * self.gpus
    if mode == "train":
        train_crawl = datacrawler.metadata[mode]["crawl"]
        self.__dataset = TDataSet(train_crawl, self.transformer)
        # Identity-balanced sampling: `instance` images per identity per batch.
        self.dataloader = TorchDataLoader(self.__dataset,
                                          batch_size=effective_batch,
                                          sampler=TSampler(train_crawl,
                                                           batch_size=effective_batch,
                                                           instance=instance * self.gpus),
                                          num_workers=self.workers,
                                          collate_fn=self.collate_simple)
        self.num_entities = datacrawler.metadata[mode]["pids"]
    elif mode == "test":
        # For testing, query and gallery images are batched together.
        combined_crawl = datacrawler.metadata[mode]["crawl"] + datacrawler.metadata["query"]["crawl"]
        self.__dataset = TDataSet(combined_crawl, self.transformer)
        self.dataloader = TorchDataLoader(self.__dataset,
                                          batch_size=effective_batch,
                                          shuffle=False,
                                          num_workers=self.workers,
                                          collate_fn=self.collate_with_camera)
        self.num_entities = len(datacrawler.metadata["query"]["crawl"])
    else:
        raise NotImplementedError()
def train_models() -> None:
    """
    Train the encoder/decoder pair using Noise-As-Targets (NAT).

    Procedure (author's notes, lightly edited):
      - Start with a smooth unit hypersphere: cover the space uniformly and
        randomly assign the fixed target representations to the input images.
      - At every iteration solve the credit-assignment problem per mini-batch:
        compute the loss between each input image and all targets in the
        mini-batch, then re-assign targets to images so that the distance
        between the current representation and its target is minimized.
      - Shuffle the dataset every epoch (Super Important): required so each
        mini-batch gets a different image/target pairing as learning goes.
        As the network learns, targets close in space should get assigned to
        similar images.

    :return: Nothing
    """
    parser = initialise_arg_parser()
    args = parser.parse_args()

    # Optionally restore a previous run's arguments from a checkpoint.
    checkpoint = None
    if args.resume:
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = get_checkpoint(args.resume)
        if checkpoint is not None:
            args = restore_args(args, checkpoint)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Input preprocessing flags; CLI options negate the defaults.
    im_grayscale = not args.im_color
    im_gradients = not args.no_gradients

    # NOTE(review): im_shape / num_classes / use_cuda appear to be
    # module-level globals defined outside this chunk — confirm.
    encoder, decoder = initialise_resnet(use_grayscale=im_grayscale,
                                         use_im_gradients=im_gradients,
                                         im_shape=im_shape,
                                         num_classes=num_classes)

    train_dir = os.path.join(args.data_dir, 'train')
    val_dir = os.path.join(args.data_dir, 'val')
    train_transforms, val_transforms = initialise_transforms(args)

    # NatImageFolder pairs each image with a fixed random target of the
    # encoder's output dimensionality (the "noise as targets").
    trainset = NatImageFolder(root=train_dir,
                              z_dims=encoder.get_output_shape(input_shape=im_shape),
                              transform=train_transforms)
    dl_trainset = DataLoader(trainset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.workers,
                             pin_memory=use_cuda)

    # end to end eval for encoder / decoder pair.
    val_set = ImageFolder(val_dir, transform=val_transforms)
    dl_val_set = TorchDataLoader(val_set,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 pin_memory=use_cuda)

    writer = SummaryWriter(log_dir=args.log_dir)

    # Freeze first layer (Compute image gradient from grayscale).
    # NOTE(review): only `encoder.features` parameters are optimized —
    # presumably this excludes the fixed gradient layer; confirm upstream.
    encoder_parameters = encoder.features.parameters()
    if use_cuda:
        encoder.cuda()
        decoder.cuda()

    encoder_loss_fn = nn.MSELoss()            # encoder regresses onto its NAT target
    decoder_loss_fn = nn.CrossEntropyLoss()   # decoder classifies encoder features
    encoder_optim = torch.optim.Adam(encoder_parameters,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    decoder_optim = torch.optim.Adam(decoder.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)

    # Lr schedule in paper constant decay until t_0 then l_0 / (1+gamma*(t-t_0)+)
    # permutation every 3 epochs
    # do the freeze + learning of classifier every 20 epochs
    scheduler = torch.optim.lr_scheduler.MultiStepLR(encoder_optim,
                                                     milestones=[10, 20],
                                                     gamma=0.5)
    if use_cuda:
        decoder_loss_fn = decoder_loss_fn.cuda()
        encoder_loss_fn = encoder_loss_fn.cuda()

    best_acc = 0.0
    if checkpoint is not None:
        # Restore full training state so the run continues where it stopped.
        best_acc = checkpoint['best_acc']
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        encoder_optim.load_state_dict(checkpoint['encoder_optim'])
        decoder_optim.load_state_dict(checkpoint['decoder_optim'])
        print(
            "=> Successfully restored All model parameters. Restarting from epoch: {}"
            .format(args.current_epoch))

    for epoch in trange(args.current_epoch, args.epochs, desc='epochs: ', leave=False):
        # NOTE(review): pre-1.1 PyTorch idiom — scheduler stepped with an
        # explicit epoch before the optimizer; newer releases warn about this.
        scheduler.step(epoch)

        # args.ut: epochs between greedy target re-assignments;
        # args.td: epochs between decoder (classifier) training passes.
        update_targets = bool(((epoch + 1) % args.ut) == 0)
        train_decoder = bool(((epoch + 1) % args.td) == 0)

        # set encoder to training mode.
        # (for batch norm layers)
        encoder.train(True)
        decoder.train(train_decoder)

        # Stream Training dataset with NAT
        for batch_idx, (idx, x, y, nat) in enumerate(tqdm(dl_trainset, 0), 1):
            e_targets = nat.numpy()
            if use_cuda:
                x, y, nat = x.cuda(), y.cuda(), nat.cuda()
            # NOTE(review): Variable(...) and .data[0] below are pre-0.4
            # PyTorch idioms; this file targets an old torch release.
            x = Variable(x)
            encoder_optim.zero_grad()
            outputs = encoder(x)

            # every few iterations greedy re-assign targets.
            if update_targets:
                e_out = outputs.cpu().data.numpy()
                new_targets = calc_optimal_target_permutation(e_out, e_targets)
                # update.
                trainset.update_targets(idx, new_targets)
                nat = torch.FloatTensor(new_targets)
                if use_cuda:
                    nat = nat.cuda()

            # train encoder
            nat = Variable(nat)
            encoder_loss = encoder_loss_fn(outputs, nat)
            # retain_graph: the same forward graph is reused by the decoder
            # backward pass below when train_decoder is set.
            encoder_loss.backward(retain_graph=True)
            encoder_optim.step()
            if batch_idx % 100 == 0:
                writer.add_scalar('encoder_loss', encoder_loss.data[0],
                                  int((epoch + 1) * (batch_idx / 100)))

            if train_decoder:
                y = Variable(y)
                decoder_optim.zero_grad()
                y_pred = decoder(outputs)
                decoder_loss = decoder_loss_fn(y_pred, y)
                decoder_loss.backward()
                decoder_optim.step()
                if batch_idx % 100 == 0:
                    idx_step = int(((epoch + 1) / args.td) * (batch_idx / 100))
                    writer.add_scalar('decoder_loss', decoder_loss.data[0], idx_step)

        # Writer weight + gradient histogram for each epoch
        for name, param in encoder.named_parameters():
            name = 'encoder/' + name.replace('.', '/')
            writer.add_histogram(name, param.clone().cpu().data.numpy(), (epoch + 1))
            if param.grad is not None:
                writer.add_histogram(name + '/grad',
                                     param.grad.clone().cpu().data.numpy(), (epoch + 1))

        # if decoder has been trained, eval classifier
        if train_decoder:
            # write decoder weight + gradient once trained for 1 epoch.
            for name, param in decoder.named_parameters():
                name = 'decoder/' + name.replace('.', '/')
                writer.add_histogram(name, param.clone().cpu().data.numpy(), (epoch + 1))
                writer.add_histogram(name + '/grad',
                                     param.grad.clone().cpu().data.numpy(), (epoch + 1))

            # set models to eval mode and validate on test set.
            encoder.eval()
            decoder.eval()
            # Pre-allocated flat buffers; assumes every batch is full size —
            # TODO confirm the val loader pads/drops the last partial batch.
            all_preds = np.empty(shape=(len(dl_val_set) * args.batch_size))
            all_truth = np.empty(shape=(len(dl_val_set) * args.batch_size))
            test_loss = 0.0
            for idx, (x, y) in tqdm(enumerate(dl_val_set, 0)):
                if use_cuda:
                    x, y = x.cuda(), y.cuda()
                x, y = Variable(x), Variable(y)
                y_pred = decoder(encoder(x))
                loss = decoder_loss_fn(y_pred, y)
                test_loss += loss.data[0]
                init_pos = idx * args.batch_size
                all_truth[init_pos:init_pos + len(y)] = y.cpu().data.numpy()
                y_hat = np.argmax(y_pred.cpu().data.numpy(), axis=1)
                all_preds[init_pos:init_pos + len(y)] = y_hat

            # compute test stats & add to tensorboard
            all_preds = all_preds.astype(np.int32)
            all_truth = all_truth.astype(np.int32)
            # acc score
            acc_score = accuracy_score(all_truth, all_preds)
            writer.add_scalar('accuracy_score', acc_score, (epoch + 1))

            if acc_score > best_acc:
                # NOTE(review): best_acc is never reassigned after saving, so
                # every epoch beating the ORIGINAL best saves a checkpoint —
                # confirm whether `best_acc = acc_score` is missing here.
                print(f'saving best encoder / decoder pair....')
                state = {
                    'args': {
                        'arch': args.arch,
                        'im_color': args.im_color,
                        'no_gradients': args.no_gradients
                    },
                    'encoder_state_dict': encoder.state_dict(),
                    'decoder_state_dict': decoder.state_dict(),
                    'encoder_optim': encoder_optim.state_dict(),
                    'decoder_optim': decoder_optim.state_dict(),
                    'best_acc': acc_score,
                    'epoch': epoch + 1,
                }
                if not os.path.isdir(args.checkpoint_dir):
                    os.mkdir(args.checkpoint_dir)
                torch.save(
                    state,
                    os.path.join(args.checkpoint_dir, f'chkpt_full_{epoch}.pkl'))
# Build the validation data pipeline and iterate it once as a smoke test.
config = CINConfig()
dataset_val = OOIDataset("val")
val_set = Dataset(dataset_val, config)
# train_set.__getitem__(5)
# train_set.__getitem__(20)
# print("dataset_train", dataset_train.num_images)
# print("train_set", train_set.__len__())
#
def my_collate_fn(batch):
    # Drop samples the Dataset returned as None (invalid/corrupt items).
    batch = list(filter(lambda x: x is not None, batch))
    if len(batch) == 0:
        # Entire batch was invalid: substitute a dummy tensor so collation succeeds.
        print("No valid data!!!")
        batch = [[torch.from_numpy(np.zeros([1, 1]))]]
    return default_collate(batch)
val_generator = TorchDataLoader(val_set, collate_fn=my_collate_fn,
                                batch_size=1, shuffle=True, num_workers=1)
step = 0
for inputs in val_generator:
    # A valid sample is expected to unpack into 17 elements —
    # TODO confirm against Dataset.__getitem__.
    if len(inputs) != 17:
        print("length of inputs", len(inputs))
        print("inputs", inputs)
        continue
    print(str(step)+"/9000")
    step += 1
# NOTE(review): the triple-quoted block below is commented-out scratch code;
# it is not terminated within this chunk and continues past the visible file.
'''
result=load_image_gt(ooiDataset,OISMPSConfig(),9,True,True)
thing_masks=result[4]
stuff_masks=result[7]
semantic_label=result[8]
influence_mask=result[11]