def get_train_loader(args):
    """Build the training DataLoader over Lab-converted ImageNet-style folders.

    Args:
        args: namespace providing data_folder, crop_low, batch_size,
            num_workers.

    Returns:
        (train_loader, n_data): the DataLoader and the number of samples.
    """
    data_folder = os.path.join(args.data_folder, 'train')

    # Lab channel ranges: L in [0, 100], a in [-86.183, 98.233],
    # b in [-107.857, 94.478]. Normalize each channel by its midpoint
    # (mean) and half-range (std).
    lab_mean = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
    lab_std = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]
    normalize = transforms.Normalize(mean=lab_mean, std=lab_std)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(args.crop_low, 1.)),
        transforms.RandomHorizontalFlip(),
        RGB2Lab(),
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = ImageFolderInstance(data_folder, transform=train_transform)

    # no distributed sampler here; shuffle whenever sampler is absent
    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.num_workers,
        pin_memory=True,
        sampler=train_sampler)

    n_data = len(train_dataset)
    print('number of samples: {}'.format(n_data))

    return train_loader, n_data
def get_train_loader(args):
    """Get the train loader.

    Builds the training DataLoader for either an ImageNet-style folder
    dataset ('Lab' or 'YCbCr' color view) or STL-10 ('Lab' view only).

    Args:
        args: namespace providing dataset, view, data_folder, crop_low,
            batch_size, num_workers.

    Returns:
        (train_loader, n_data): the DataLoader and the number of samples.

    Raises:
        NotImplementedError: if args.view is unsupported on the imagenet path.
    """
    if 'imagenet' in args.dataset:
        data_folder = os.path.join(args.data_folder, 'train')
        if args.view == 'Lab':
            # Lab ranges: L in [0, 100], a in [-86.183, 98.233],
            # b in [-107.857, 94.478]; mean = midpoint, std = half-range.
            mean = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
            std = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]
            color_transfer = RGB2Lab()
        elif args.view == 'YCbCr':
            mean = [116.151, 121.080, 132.342]
            std = [109.500, 111.855, 111.964]
            color_transfer = RGB2YCbCr()
        else:
            # BUG FIX: was `raise NotImplemented(...)` — NotImplemented is a
            # sentinel value, not an exception class, so that line raised a
            # confusing TypeError instead of the intended error.
            raise NotImplementedError('view not implemented {}'.format(args.view))
        normalize = transforms.Normalize(mean=mean, std=std)
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(args.crop_low, 1.)),
            transforms.RandomHorizontalFlip(),
            color_transfer,
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = ImageFolderInstance(data_folder, transform=train_transform)
    else:
        assert args.dataset == 'stl10'
        assert args.view == 'Lab'
        mean = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
        std = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(64, scale=(args.crop_low, 1)),
            transforms.RandomHorizontalFlip(),
            RGB2Lab(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])
        # STL-10 uses both labeled train split and the unlabeled split
        train_dataset = datasets.STL10(
            args.data_folder, 'train+unlabeled',
            transform=train_transform, download=True)
        train_dataset = DatasetInstance(train_dataset)

    train_sampler = None

    # train loader
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)

    # num of samples
    n_data = len(train_dataset)
    print('number of samples: {}'.format(n_data))

    return train_loader, n_data
def main():
    """Train a single-class FCN-ResNet50 segmentation model with SGD."""
    # NOTE(review): `args` is read from module scope here — presumably set
    # by an argument parser at entry time; confirm before refactoring.
    augmented = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
        transforms.ToTensor(),
    ])
    plain = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    train_dataset = ImageFolderInstance(args.data, augmented, plain)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)

    model = torch.nn.DataParallel(fcn_resnet50(num_classes=1)).cuda()

    # args.loss selects the criterion: falsy -> dice loss, truthy -> CE
    criterion = (nn.CrossEntropyLoss() if args.loss else diceLoss()).cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    for epoch in range(args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # checkpoint every 5 epochs (including epoch 0)
        if (epoch % 5) == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': "res50",
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            })
def get_train_loader(args):
    """Get the train loader.

    Builds the training DataLoader ('Lab' or 'YCbCr' color view) and also
    returns the per-sample labels in dataset order.

    Args:
        args: namespace providing view, data_folder, crop_low, batch_size,
            num_workers.

    Returns:
        (train_loader, train_ordered_labels, n_data): the DataLoader, an
        np.ndarray of labels indexed by sample position, and the sample count.

    Raises:
        NotImplementedError: if args.view is not supported.
    """
    data_folder = os.path.join(args.data_folder, 'train')

    if args.view == 'Lab':
        # Lab ranges: L in [0, 100], a in [-86.183, 98.233],
        # b in [-107.857, 94.478]; mean = midpoint, std = half-range.
        mean = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
        std = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]
        color_transfer = RGB2Lab()
    elif args.view == 'YCbCr':
        mean = [116.151, 121.080, 132.342]
        std = [109.500, 111.855, 111.964]
        color_transfer = RGB2YCbCr()
    else:
        # BUG FIX: was `raise NotImplemented(...)` — NotImplemented is not an
        # exception class, so that line raised a TypeError.
        raise NotImplementedError('view not implemented {}'.format(args.view))

    normalize = transforms.Normalize(mean=mean, std=std)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(args.crop_low, 1.)),
        transforms.RandomHorizontalFlip(),
        color_transfer,
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = ImageFolderInstance(data_folder, transform=train_transform)
    train_sampler = None

    # Labels in dataset order (e.g. for memory-bank bookkeeping).
    # NOTE(review): assumes the instance dataset wraps an ImageFolder at
    # `.dataset` exposing `.samples` as (path, label) pairs — confirm.
    train_samples = train_dataset.dataset.samples
    train_ordered_labels = np.array([label for _, label in train_samples])

    # train loader
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)

    # num of samples
    n_data = len(train_dataset)
    print('number of samples: {}'.format(n_data))

    return train_loader, train_ordered_labels, n_data
def get_test_loader(args):
    """Get the validation loader.

    Builds the evaluation DataLoader over the 'validation' split with a
    resize + center-crop pipeline in the selected color view.

    Args:
        args: namespace providing view, data_folder, batch_size, num_workers.

    Returns:
        (test_loader, n_data): the DataLoader and the number of samples.

    Raises:
        NotImplementedError: if args.view is not supported.
    """
    data_folder = os.path.join(args.data_folder, 'validation')

    if args.view == 'Lab':
        # Lab ranges: L in [0, 100], a in [-86.183, 98.233],
        # b in [-107.857, 94.478]; mean = midpoint, std = half-range.
        mean = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
        std = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]
        color_transfer = RGB2Lab()
    elif args.view == 'YCbCr':
        mean = [116.151, 121.080, 132.342]
        std = [109.500, 111.855, 111.964]
        color_transfer = RGB2YCbCr()
    else:
        # BUG FIX: was `raise NotImplemented(...)` — NotImplemented is not an
        # exception class, so that line raised a TypeError.
        raise NotImplementedError('view not implemented {}'.format(args.view))

    normalize = transforms.Normalize(mean=mean, std=std)

    # BUG FIX: `image_size` was referenced but never defined in this function
    # (NameError at runtime). The original FIXME noted the pipeline was
    # hard-coded for 224; make that explicit here.
    image_size = 224
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(image_size),
        color_transfer,
        transforms.ToTensor(),
        normalize,
    ])
    test_dataset = ImageFolderInstance(data_folder, transform=test_transform)
    test_sampler = None

    # test loader
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size,
        shuffle=(test_sampler is None),
        num_workers=args.num_workers, pin_memory=True, sampler=test_sampler)

    # num of samples
    n_data = len(test_dataset)
    print('number of samples: {}'.format(n_data))

    return test_loader, n_data
def main():
    """Entry point: unsupervised contrastive pre-training (InsDis or MoCo).

    Parses options, builds the ImageNet-style train loader, creates the model
    (plus EMA model for MoCo), the contrast memory and NCE criterion, then
    runs the training loop with per-epoch tensorboard logging and
    checkpointing.
    """
    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # ---------------- data ----------------
    data_folder = os.path.join(args.data_folder, 'train')

    image_size = 224
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    if args.aug == 'NULL':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'CJ':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        # BUG FIX: was `raise NotImplemented(...)` — NotImplemented is not an
        # exception class, so that line raised a TypeError.
        raise NotImplementedError('augmentation not supported: {}'.format(args.aug))

    # two_crop: MoCo needs two augmented views per image
    train_dataset = ImageFolderInstance(data_folder, transform=train_transform, two_crop=args.moco)
    print(len(train_dataset))
    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)

    # ---------------- model and optimizer ----------------
    n_data = len(train_dataset)

    if args.model == 'resnet50':
        model = InsResNet50()
        if args.moco:
            model_ema = InsResNet50()
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2)
        if args.moco:
            model_ema = InsResNet50(width=2)
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4)
        if args.moco:
            model_ema = InsResNet50(width=4)
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))

    # copy weights from `model' to `model_ema'
    if args.moco:
        moment_update(model, model_ema, 0)

    # set the contrast memory and criterion
    if args.moco:
        contrast = MemoryMoCo(128, n_data, args.nce_k, args.nce_t, args.softmax).cuda(args.gpu)
    else:
        contrast = MemoryInsDis(128, n_data, args.nce_k, args.nce_t, args.nce_m, args.softmax).cuda(args.gpu)

    criterion = NCESoftmaxLoss() if args.softmax else NCECriterion(n_data)
    criterion = criterion.cuda(args.gpu)

    # NOTE(review): the model uses bare .cuda() while contrast/criterion use
    # .cuda(args.gpu); equivalent on the default device, confirm for multi-GPU.
    model = model.cuda()
    if args.moco:
        model_ema = model_ema.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level)
        if args.moco:
            # EMA model is updated by momentum, not by this optimizer (lr=0);
            # it exists only so amp can wrap the model.
            optimizer_ema = torch.optim.SGD(model_ema.parameters(),
                                            lr=0, momentum=0, weight_decay=0)
            model_ema, optimizer_ema = amp.initialize(model_ema, optimizer_ema, opt_level=args.opt_level)

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # map to CPU first to avoid GPU OOM on load
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            contrast.load_state_dict(checkpoint['contrast'])
            if args.moco:
                model_ema.load_state_dict(checkpoint['model_ema'])
            if args.amp and checkpoint['opt'].amp:
                print('==> resuming amp state_dict')
                amp.load_state_dict(checkpoint['amp'])
            print("=> loaded successfully '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):

        adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        if args.moco:
            loss, prob = train_moco(epoch, train_loader, model, model_ema, contrast, criterion, optimizer, args)
        else:
            loss, prob = train_ins(epoch, train_loader, model, contrast, criterion, optimizer, args)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('ins_loss', loss, epoch)
        logger.log_value('ins_prob', prob, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        # BUG FIX: the checkpoint dict was built and the identical
        # ckpt_epoch_{epoch}.pth was written twice per save epoch; build the
        # state once, always refresh current.pth, and write the epoch-tagged
        # checkpoint every args.save_freq epochs.
        print('==> Saving...')
        state = {
            'opt': args,
            'model': model.state_dict(),
            'contrast': contrast.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        if args.moco:
            state['model_ema'] = model_ema.state_dict()
        if args.amp:
            state['amp'] = amp.state_dict()
        torch.save(state, os.path.join(args.model_folder, 'current.pth'))
        if epoch % args.save_freq == 0:
            save_file = os.path.join(
                args.model_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
        # help release GPU memory
        del state
        torch.cuda.empty_cache()
def main():
    """Entry point: contrastive pre-training (InsDis, MoCo or SimCLR) on
    ImageNet or CIFAR-10, with periodic kNN evaluation for CIFAR-10.

    Parses options, builds train/test loaders, creates the model (plus EMA
    model for MoCo), contrast memory and criterion, then runs the training
    loop with tensorboard logging, periodic evaluation and checkpointing.
    """
    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # ---------------- data ----------------
    data_folder = os.path.join(args.data_folder, 'train')
    val_folder = os.path.join(args.data_folder, 'val')

    crop_padding = 32
    image_size = 224
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    # ImageNet evaluation pipeline (resize + center crop), shared by the
    # imagenet branches below.
    imagenet_test_transform = transforms.Compose([
        transforms.Resize(image_size + crop_padding),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize,
    ])

    if args.aug == 'NULL' and args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        # BUG FIX: test_transform was never defined on this path, causing a
        # NameError when the imagenet test loader was built below.
        test_transform = imagenet_test_transform
    elif args.aug == 'CJ':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        # BUG FIX: same missing test_transform as above.
        # NOTE(review): for dataset == 'cifar' with aug 'CJ' this
        # imagenet-sized transform is likely unintended — confirm the
        # supported aug/dataset combinations.
        test_transform = imagenet_test_transform
    elif args.aug == 'simple' and args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            get_color_distortion(1.0),
            transforms.ToTensor(),
            normalize,
        ])
        # TODO: Currently follow CMC
        test_transform = imagenet_test_transform
    elif args.aug == 'simple' and args.dataset == 'cifar':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(size=32),
            transforms.RandomHorizontalFlip(p=0.5),
            get_color_distortion(0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    else:
        # BUG FIX: was `raise NotImplemented(...)` — NotImplemented is not an
        # exception class, so that line raised a TypeError.
        raise NotImplementedError('augmentation not supported: {}'.format(args.aug))

    # ---------------- datasets ----------------
    if args.dataset == "imagenet":
        train_dataset = ImageFolderInstance(data_folder, transform=train_transform, two_crop=args.moco)
        print(len(train_dataset))
        train_sampler = None
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)

        # BUG FIX: the keyword was `transforms=` which is not a parameter of
        # datasets.ImageFolder (TypeError); it must be `transform=`.
        test_dataset = datasets.ImageFolder(val_folder, transform=test_transform)
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=256, shuffle=False,
            num_workers=args.num_workers, pin_memory=True)
    elif args.dataset == 'cifar':
        # cifar-10 dataset; SimCLR needs two views per image
        if args.contrastive_model == 'simclr':
            train_dataset = CIFAR10Instance_double(root='./data', train=True, download=True,
                                                   transform=train_transform, double=True)
        else:
            train_dataset = CIFAR10Instance(root='./data', train=True, download=True,
                                            transform=train_transform)
        train_sampler = None
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.num_workers, pin_memory=True,
            sampler=train_sampler, drop_last=True)

        test_dataset = CIFAR10Instance(root='./data', train=False, download=True,
                                       transform=test_transform)
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=100, shuffle=False,
            num_workers=args.num_workers)

    # ---------------- model and optimizer ----------------
    n_data = len(train_dataset)

    if args.model == 'resnet50':
        model = InsResNet50()
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50()
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2)
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50(width=2)
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4)
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50(width=4)
    elif args.model == 'resnet50_cifar':
        model = InsResNet50_cifar()
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50_cifar()
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))

    # copy weights from `model' to `model_ema'
    if args.contrastive_model == 'moco':
        moment_update(model, model_ema, 0)

    # set the contrast memory and criterion (SimCLR needs no memory bank)
    if args.contrastive_model == 'moco':
        contrast = MemoryMoCo(128, n_data, args.nce_k, args.nce_t, args.softmax).cuda(args.gpu)
    elif args.contrastive_model == 'simclr':
        contrast = None
    else:
        contrast = MemoryInsDis(128, n_data, args.nce_k, args.nce_t, args.nce_m, args.softmax).cuda(args.gpu)

    if args.softmax:
        criterion = NCESoftmaxLoss()
    elif args.contrastive_model == 'simclr':
        criterion = BatchCriterion(1, args.nce_t, args.batch_size)
    else:
        criterion = NCECriterion(n_data)
    criterion = criterion.cuda(args.gpu)

    model = model.cuda()
    if args.contrastive_model == 'moco':
        model_ema = model_ema.cuda()

    # Exclude BN and bias from weight decay if requested; the decay is then
    # applied per-group inside add_weight_decay, so zero it on the optimizer.
    weight_decay = args.weight_decay
    if weight_decay and args.filter_weight_decay:
        parameters = add_weight_decay(model, weight_decay, args.filter_weight_decay)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    optimizer = torch.optim.SGD(parameters,
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=weight_decay)
    cudnn.benchmark = True

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level)
        if args.contrastive_model == 'moco':
            # EMA model is updated by momentum, not by this optimizer (lr=0)
            optimizer_ema = torch.optim.SGD(model_ema.parameters(),
                                            lr=0, momentum=0, weight_decay=0)
            model_ema, optimizer_ema = amp.initialize(model_ema, optimizer_ema, opt_level=args.opt_level)

    if args.LARS:
        optimizer = LARS(optimizer=optimizer, eps=1e-8, trust_coef=0.001)

    # optionally resume from a checkpoint
    args.start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if contrast:
                contrast.load_state_dict(checkpoint['contrast'])
            if args.contrastive_model == 'moco':
                model_ema.load_state_dict(checkpoint['model_ema'])
            if args.amp and checkpoint['opt'].amp:
                print('==> resuming amp state_dict')
                amp.load_state_dict(checkpoint['amp'])
            print("=> loaded successfully '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):

        print("==> training...")

        time1 = time.time()
        if args.contrastive_model == 'moco':
            loss, prob = train_moco(epoch, train_loader, model, model_ema, contrast, criterion, optimizer, args)
        elif args.contrastive_model == 'simclr':
            print("Train using simclr")
            loss, prob = train_simclr(epoch, train_loader, model, criterion, optimizer, args)
        else:
            print("Train using InsDis")
            loss, prob = train_ins(epoch, train_loader, model, contrast, criterion, optimizer, args)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('ins_loss', loss, epoch)
        logger.log_value('ins_prob', prob, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        test_epoch = 2
        if epoch % test_epoch == 0:
            # NOTE(review): models are switched to eval() here but never
            # switched back to train() — presumably the train_* functions do
            # that themselves; confirm.
            model.eval()
            if args.contrastive_model == 'moco':
                model_ema.eval()
            print('----------Evaluation---------')
            start = time.time()
            if args.dataset == 'cifar':
                acc = kNN(epoch, model, train_loader, test_loader, 200, args.nce_t, n_data,
                          low_dim=128, memory_bank=None)
                # BUG FIX: `acc` was logged/printed unconditionally although it
                # is only computed for cifar, raising a NameError for other
                # datasets; keep the accuracy reporting inside this branch.
                logger.log_value('Test accuracy', acc, epoch)
                print('[Epoch]: {}'.format(epoch))
                print('accuracy: {}%)'.format(acc))
            print("Evaluation Time: '{}'s".format(time.time() - start))

        # BUG FIX: the checkpoint dict was built and the identical
        # ckpt_epoch_{epoch}.pth was written twice per save epoch; build the
        # state once, always refresh current.pth, and write the epoch-tagged
        # checkpoint every args.save_freq epochs.
        print('==> Saving...')
        state = {
            'opt': args,
            'model': model.state_dict(),
            # 'contrast': contrast.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        if args.contrastive_model == 'moco':
            state['model_ema'] = model_ema.state_dict()
        if args.amp:
            state['amp'] = amp.state_dict()
        torch.save(state, os.path.join(args.model_folder, 'current.pth'))
        if epoch % args.save_freq == 0:
            save_file = os.path.join(
                args.model_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
        # help release GPU memory
        del state
        torch.cuda.empty_cache()
def get_train_loader(args):
    """Get the train loader.

    Builds the training DataLoader ('Lab' or 'YCbCr' view) and, when args.IM
    is set, wraps the dataset in exactly one of the augmentation spaces
    selected by args.IM_type ('IM', 'global', 'region', 'Cutout',
    'RandomErasing').

    Args:
        args: namespace providing view, data_folder, crop_low, batch_size,
            num_workers, IM, IM_type and the wrapper-specific hyperparameters.

    Returns:
        (train_loader, n_data): the DataLoader and the number of samples.

    Raises:
        NotImplementedError: if args.view is not supported.
    """
    data_folder = os.path.join(args.data_folder, 'train')

    if args.view == 'Lab':
        # Lab ranges: L in [0, 100], a in [-86.183, 98.233],
        # b in [-107.857, 94.478]; mean = midpoint, std = half-range.
        mean = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
        std = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]
        color_transfer = RGB2Lab()
    elif args.view == 'YCbCr':
        mean = [116.151, 121.080, 132.342]
        std = [109.500, 111.855, 111.964]
        color_transfer = RGB2YCbCr()
    else:
        # BUG FIX: was `raise NotImplemented(...)` — NotImplemented is not an
        # exception class, so that line raised a TypeError.
        raise NotImplementedError('view not implemented {}'.format(args.view))

    normalize = transforms.Normalize(mean=mean, std=std)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(args.crop_low, 1.)),
        transforms.RandomHorizontalFlip(),
        color_transfer,
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = ImageFolderInstance(data_folder, transform=train_transform)
    train_sampler = None

    if args.IM:
        # IM_type values are mutually exclusive; elif makes that explicit and
        # skips the remaining comparisons once one matches.
        if args.IM_type == 'IM':
            print("using IM space.................")
            train_dataset = IM(train_dataset,
                               g_alpha=args.g_alpha, g_num_mix=args.g_num, g_prob=args.g_prob,
                               r_beta=args.r_beta, r_prob=args.r_prob, r_num_mix=args.r_num,
                               r_decay=args.r_pixel_decay)
        elif args.IM_type == 'global':
            print("using global space.................")
            train_dataset = global_(train_dataset,
                                    g_alpha=args.g_alpha, g_num_mix=args.g_num, g_prob=args.g_prob)
        elif args.IM_type == 'region':
            print("using region space.................")
            train_dataset = region(train_dataset,
                                   r_beta=args.r_beta, r_prob=args.r_prob, r_num_mix=args.r_num,
                                   r_decay=args.r_pixel_decay)
        elif args.IM_type == 'Cutout':
            print("using Cutout aug.................")
            train_dataset = Cutout(train_dataset,
                                   mask_size=args.mask_size, p=args.cutout_p,
                                   cutout_inside=args.cutout_inside, mask_color=args.mask_color)
        elif args.IM_type == 'RandomErasing':
            print("using RandomErasing aug.................")
            train_dataset = RandomErasing(
                train_dataset, p=args.random_erasing_prob,
                area_ratio_range=args.area_ratio_range,
                min_aspect_ratio=args.min_aspect_ratio,
                max_attempt=args.max_attempt)

    # train loader
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)

    # num of samples
    n_data = len(train_dataset)
    print('number of samples: {}'.format(n_data))

    return train_loader, n_data