    train_loader = DataLoader(train_data,
                              batch_size=cfg.BATCH_SIZE,
                              shuffle=True,
                              num_workers=cfg.WORKERS,
                              pin_memory=True,
                              drop_last=True)
    val_loader = DataLoader(val_data,
                            batch_size=8,
                            shuffle=False,
                            num_workers=cfg.WORKERS,
                            pin_memory=True)
    # val_loader = None
    unlabeled_loader = None
    num_train = len(train_data)
    num_val = len(val_data)

    cfg.CLASS_WEIGHTS_TRAIN = train_data.class_weights
    cfg.IGNORE_LABEL = train_data.ignore_label

    # shell script to run
    print('LOSS_TYPES:', cfg.LOSS_TYPES)
    writer = SummaryWriter(log_dir=cfg.LOG_PATH)  # tensorboard

    if cfg.MULTI_MODAL:
        model = TRecgNet_MULTIMODAL(cfg, writer=writer)
    else:
        model = TRecgNet(cfg, writer=writer)
    model.set_data_loader(train_loader, val_loader, unlabeled_loader,
                          num_train, num_val)


def train():
        cfg, data_dir=cfg.DATA_DIR_VAL,
        transform=transforms.Compose([
            SPL10.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)),
            SPL10.CenterCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)),
            RGB2Lab(),
            SPL10.ToTensor(),
            SPL10.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
        ]))
    train_loader = DataProvider(cfg, dataset=train_dataset)
    val_loader = DataProvider(cfg, dataset=val_dataset, shuffle=False)

    num_classes_train = list(
        Counter([i[1] for i in train_loader.dataset.imgs]).values())
    cfg.CLASS_WEIGHTS_TRAIN = torch.FloatTensor(num_classes_train)

    model = Contrastive_CrossModal_Conc(cfg, device=device)
    model = nn.DataParallel(model).to(device)
    optim = Adam(model.parameters(), lr=cfg.LR)

    load_model = False
    if load_model:
        model = torch.load('./checkpoint/model_1_LAB.mdl')
        print("load pretrained model")

    # loss_optim = Adam(infomax_fn.parameters(), lr=2e-4)
    # cls_criterion = torch.nn.CrossEntropyLoss(cfg.CLASS_WEIGHTS_TRAIN.to(device))
    scheduler_optim = get_scheduler(optim)
    # scheduler_loss_optim = get_scheduler(loss_optim)
    epoch_restart = None
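
# get_scheduler() is called above but its body is not part of this fragment.
# A minimal sketch of such a helper, assuming a linear learning-rate decay
# policy; the decay rule and the niter_total / niter_decay parameters are
# assumptions for illustration, not taken from this repository.
from torch.optim import lr_scheduler

def get_scheduler_sketch(optimizer, niter_total, niter_decay):
    """Keep the base LR for the first (niter_total - niter_decay) epochs,
    then decay it linearly to zero over the remaining niter_decay epochs."""
    def lambda_rule(epoch):
        keep = niter_total - niter_decay
        return max(0.0, 1.0 - max(0, epoch - keep) / float(niter_decay))
    return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

# usage mirroring the call above, stepped once per epoch:
# scheduler_optim = get_scheduler_sketch(optim, cfg.NITER_TOTAL, cfg.NITER_DECAY)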
def main():
    cfg = DefaultConfig()
    args = {
        'seg_resnet_sunrgbd': SEG_RESNET_SUNRGBD_CONFIG().args(),
        'seg_resnet_cityscapes': SEG_RESNET_CITYSCAPE_CONFIG().args(),
        'rec_resnet_sunrgbd': REC_RESNET_SUNRGBD_CONFIG().args(),
        'rec_resnet_nyud2': REC_RESNET_NYUD2_CONFIG().args(),
        'rec_resnet_mit67': REC_RESNET_MIT67_CONFIG().args(),
        'infomax_resnet_sunrgbd': INFOMAX_RESNET_SUNRGBD_CONFIG().args(),
        'infomax_resnet_nyud2': INFOMAX_RESNET_NYUD2_CONFIG().args()
    }

    # use shell
    if len(sys.argv) > 1:
        device_ids = torch.cuda.device_count()
        print('device_ids:', device_ids)
        gpu_ids, config_key = sys.argv[1:]
        cfg.parse(args[config_key])
        cfg.GPU_IDS = gpu_ids.split(',')
    else:
        # seg_resnet_sunrgbd
        # seg_resnet_cityscapes
        # infomax_resnet_sunrgbd
        # rec_resnet_sunrgbd
        # rec_resnet_nyud2
        # rec_resnet_mit67
        # infomax_resnet_nyud2
        config_key = 'rec_resnet_sunrgbd'
        cfg.parse(args[config_key])

    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
        map(lambda x: str(x), cfg.GPU_IDS))

    trans_task = '' + cfg.WHICH_DIRECTION
    if not cfg.NO_TRANS:
        if cfg.MULTI_MODAL:
            trans_task = trans_task + '_multimodal_'
        if 'SEMANTIC' in cfg.LOSS_TYPES or 'PIX2PIX' in cfg.LOSS_TYPES:
            trans_task = trans_task + '_alpha_' + str(cfg.ALPHA_CONTENT)

    evaluate_type = 'sliding_window' if cfg.SLIDE_WINDOWS else 'center_crop'
    log_name = ''.join(
        [cfg.TASK, '_', cfg.ARCH, '_', trans_task, '_', cfg.DATASET])
    cfg.LOG_NAME = ''.join([
        log_name, '_', '.'.join(cfg.LOSS_TYPES), '_', evaluate_type,
        '_gpus_', str(len(cfg.GPU_IDS)), '_',
        datetime.now().strftime('%b%d_%H-%M-%S')
    ])
    cfg.LOG_PATH = os.path.join(cfg.LOG_PATH, cfg.MODEL, cfg.LOG_NAME)

    # Setting random seed
    if cfg.MANUAL_SEED is None:
        cfg.MANUAL_SEED = random.randint(1, 10000)
    random.seed(cfg.MANUAL_SEED)
    torch.manual_seed(cfg.MANUAL_SEED)
    torch.backends.cudnn.benchmark = True
    # cudnn.deterministic = True

    project_name = reduce(lambda x, y: str(x) + '/' + str(y),
                          os.path.realpath(__file__).split(os.sep)[:-1])
    print('>>> task path is {0}'.format(project_name))
    util.mkdir('logs')

    # dataset = segmentation_dataset_cv2
    train_transforms = list()
    val_transforms = list()
    ms_targets = []

    train_transforms.append(dataset.Resize(cfg.LOAD_SIZE))
    # train_transforms.append(dataset.RandomScale(cfg.RANDOM_SCALE_SIZE))
    # train_transforms.append(dataset.RandomRotate())
    # train_transforms.append(dataset.RandomCrop_Unaligned(cfg.FINE_SIZE, pad_if_needed=True, fill=0))
    train_transforms.append(
        dataset.RandomCrop(cfg.FINE_SIZE, pad_if_needed=True, fill=0))
    train_transforms.append(dataset.RandomHorizontalFlip())
    if cfg.TARGET_MODAL == 'lab':
        train_transforms.append(dataset.RGB2Lab())
    if cfg.MULTI_SCALE:
        for item in cfg.MULTI_TARGETS:
            ms_targets.append(item)
        train_transforms.append(
            dataset.MultiScale(size=cfg.FINE_SIZE,
                               scale_times=cfg.MULTI_SCALE_NUM,
                               ms_targets=ms_targets))
    train_transforms.append(dataset.ToTensor(ms_targets=ms_targets))
    train_transforms.append(
        dataset.Normalize(mean=cfg.MEAN, std=cfg.STD, ms_targets=ms_targets))

    val_transforms.append(dataset.Resize(cfg.LOAD_SIZE))
    if not cfg.SLIDE_WINDOWS:
        val_transforms.append(dataset.CenterCrop(cfg.FINE_SIZE))
    if cfg.MULTI_SCALE:
        val_transforms.append(
            dataset.MultiScale(size=cfg.FINE_SIZE,
                               scale_times=cfg.MULTI_SCALE_NUM,
                               ms_targets=ms_targets))
    val_transforms.append(dataset.ToTensor(ms_targets=ms_targets))
    val_transforms.append(
        dataset.Normalize(mean=cfg.MEAN, std=cfg.STD, ms_targets=ms_targets))

    train_dataset = dataset.__dict__[cfg.DATASET](
        cfg=cfg,
        transform=transforms.Compose(train_transforms),
        data_dir=cfg.DATA_DIR_TRAIN,
        phase_train=True)
    val_dataset = dataset.__dict__[cfg.DATASET](
        cfg=cfg,
        transform=transforms.Compose(val_transforms),
        data_dir=cfg.DATA_DIR_VAL,
        phase_train=False)

    cfg.CLASS_WEIGHTS_TRAIN = train_dataset.class_weights
    cfg.IGNORE_LABEL = train_dataset.ignore_label
    cfg.train_dataset = train_dataset
    cfg.val_dataset = val_dataset

    port = random.randint(8001, 9000)
    ngpus_per_node = len(cfg.GPU_IDS)
    if cfg.MULTIPROCESSING_DISTRIBUTED:
        cfg.rank = 0
        cfg.ngpus_per_node = ngpus_per_node
        cfg.dist_url = 'tcp://127.0.0.1:' + str(port)
        cfg.dist_backend = 'nccl'
        cfg.opt_level = 'O0'
        cfg.world_size = 1

    cfg.print_args()

    if cfg.MULTIPROCESSING_DISTRIBUTED:
        cfg.world_size = cfg.ngpus_per_node * cfg.world_size
        mp.spawn(main_worker,
                 nprocs=cfg.ngpus_per_node,
                 args=(cfg.ngpus_per_node, cfg))
    else:
        # Simply call main_worker function
        main_worker(cfg.GPU_IDS, ngpus_per_node, cfg)
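
# main_worker is defined elsewhere in the repository. The sketch below only
# illustrates the call signature mp.spawn expects (the process index is
# prepended automatically) and how the cfg.rank / cfg.dist_* fields prepared
# above are typically consumed; it is an assumption for illustration, not the
# actual implementation.
import torch
import torch.distributed as dist

def main_worker_sketch(gpu, ngpus_per_node, cfg):
    rank = cfg.rank * ngpus_per_node + gpu  # global rank of this process
    dist.init_process_group(backend=cfg.dist_backend,
                            init_method=cfg.dist_url,
                            world_size=cfg.world_size,
                            rank=rank)
    torch.cuda.set_device(gpu)
    # ... build the model on this GPU, wrap it in
    # torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu]),
    # give each DataLoader a DistributedSampler, then run the training loop.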
def train():
    cfg = DefaultConfig()
    args = {
        'resnet18': RESNET18_SUNRGBD_CONFIG().args(),
    }

    # Setting random seed
    if cfg.MANUAL_SEED is None:
        cfg.MANUAL_SEED = random.randint(1, 10000)
    random.seed(cfg.MANUAL_SEED)
    torch.manual_seed(cfg.MANUAL_SEED)

    # args for different backbones
    cfg.parse(args['resnet18'])

    os.environ["CUDA_VISIBLE_DEVICES"] = cfg.GPU_IDS
    device_ids = torch.cuda.device_count()
    print('device_ids:', device_ids)
    project_name = reduce(lambda x, y: str(x) + '/' + str(y),
                          os.path.realpath(__file__).split(os.sep)[:-1])
    util.mkdir('logs')

    # data
    train_dataset = dataset.AlignedConcDataset(
        cfg,
        data_dir=cfg.DATA_DIR_TRAIN,
        transform=transforms.Compose([
            dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)),
            dataset.RandomCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)),
            dataset.RandomHorizontalFlip(),
            dataset.ToTensor(),
            dataset.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ]))

    val_dataset = dataset.AlignedConcDataset(
        cfg,
        data_dir=cfg.DATA_DIR_VAL,
        transform=transforms.Compose([
            dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)),
            dataset.CenterCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)),
            dataset.ToTensor(),
            dataset.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ]))

    batch_size_val = cfg.BATCH_SIZE

    unlabeled_loader = None
    if cfg.UNLABELED:
        unlabeled_dataset = dataset.AlignedConcDataset(
            cfg,
            data_dir=cfg.DATA_DIR_UNLABELED,
            transform=transforms.Compose([
                dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)),
                dataset.RandomCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)),
                dataset.RandomHorizontalFlip(),
                dataset.ToTensor(),
                dataset.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            ]),
            labeled=False)

        unlabeled_loader = DataProvider(cfg, dataset=unlabeled_dataset)

    train_loader = DataProvider(cfg,
                                dataset=train_dataset,
                                batch_size=batch_size_val)
    val_loader = DataProvider(cfg,
                              dataset=val_dataset,
                              batch_size=batch_size_val,
                              shuffle=False)

    # class weights
    num_classes_train = list(
        Counter([i[1] for i in train_loader.dataset.imgs]).values())
    cfg.CLASS_WEIGHTS_TRAIN = torch.FloatTensor(num_classes_train)

    writer = SummaryWriter(log_dir=cfg.LOG_PATH)  # tensorboard
    model = TRecgNet(cfg, writer=writer)
    model.set_data_loader(train_loader, val_loader, unlabeled_loader)

    if cfg.RESUME:
        checkpoint_path = os.path.join(cfg.CHECKPOINTS_DIR, cfg.RESUME_PATH)
        checkpoint = torch.load(checkpoint_path)
        load_epoch = checkpoint['epoch']
        model.load_checkpoint(model.net, checkpoint_path, checkpoint,
                              data_para=True)
        cfg.START_EPOCH = load_epoch

        if cfg.INIT_EPOCH:
            # just load pretrained parameters
            print('load checkpoint from another source')
            cfg.START_EPOCH = 1

    print('>>> task path is {0}'.format(project_name))

    # train
    model.train_parameters(cfg)

    print('save model ...')
    model_filename = '{0}_{1}_{2}.pth'.format(cfg.MODEL, cfg.WHICH_DIRECTION,
                                              cfg.NITER_TOTAL)
    model.save_checkpoint(cfg.NITER_TOTAL, model_filename)

    if writer is not None:
        writer.close()
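
# Note on CLASS_WEIGHTS_TRAIN above: Counter(...).values() yields raw per-class
# image counts in first-encounter order, so the tensor holds counts rather than
# normalised weights and is not guaranteed to be ordered by class index. If
# inverse-frequency weights indexed by class label are what the criterion
# expects (an assumption -- the counts may be consumed as-is inside TRecgNet),
# a sketch like the following would build them; cfg.NUM_CLASSES is assumed for
# illustration only.
from collections import Counter
import torch

def inverse_frequency_weights(labels, num_classes):
    """labels: iterable of integer class ids. Returns a FloatTensor of length
    num_classes, proportional to 1 / class frequency and normalised to mean 1."""
    counts = Counter(labels)
    freq = torch.tensor([counts.get(c, 1) for c in range(num_classes)],
                        dtype=torch.float32)
    weights = freq.sum() / (num_classes * freq)   # inverse frequency
    return weights / weights.mean()               # keep the average weight at 1

# usage with the dataset above (illustrative):
# labels = [i[1] for i in train_loader.dataset.imgs]
# cfg.CLASS_WEIGHTS_TRAIN = inverse_frequency_weights(labels, cfg.NUM_CLASSES)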