# Validation loader: shuffle=False keeps a fixed iteration order.
dataloader_val = torch.utils.data.DataLoader(
    dataset=dataset_val,
    batch_size=batch_size,
    num_workers=10,
    collate_fn=collate,
    shuffle=False,
    pin_memory=True,
)
# NOTE(review): the "test" dataset is built from ``val_df`` with
# ``label=False``; confirm this is intended rather than a separate test split.
dataset_test = IDRID(data_path, val_df, label=False)
dataloader_test = torch.utils.data.DataLoader(
    dataset=dataset_test,
    batch_size=batch_size,
    num_workers=10,
    collate_fn=collate_test,
    shuffle=False,
    pin_memory=True,
)

##### sizes are (w, h) ##############################
# make sure margin / 32 is over 1.5 AND size_g is divisible by 4
size_g = (args.size_g, args.size_g)  # resized global image
size_p = (args.size_p, args.size_p)  # cropped local patch size
sub_batch_size = args.sub_batch_size  # batch size for training local patches

###################################
print("creating models......")
# Checkpoint paths for the three model parts, all relative to model_path.
path_g = os.path.join(model_path, args.path_g)
path_g2l = os.path.join(model_path, args.path_g2l)
path_l2g = os.path.join(model_path, args.path_l2g)
model, global_fixed = create_model_load_weights(
    n_class,
    mode,
    evaluation,
    path_g=path_g,
    path_g2l=path_g2l,
    path_l2g=path_l2g,
)

###################################
num_epochs = args.num_epochs
learning_rate = args.lr
lamb_fmreg = args.lamb_fmreg

optimizer = get_optimizer(model, mode, learning_rate=learning_rate)
# 'poly' schedule (presumably polynomial decay); it is given the number of
# training batches per epoch via len(dataloader_train).
scheduler = LR_Scheduler('poly', learning_rate, num_epochs, len(dataloader_train))

##################################
criterion1 = FocalLoss(gamma=3)
criterion2 = nn.CrossEntropyLoss()
criterion3 = lovasz_softmax


def criterion(x, y):
    """Return the training loss for prediction ``x`` and target ``y``.

    Currently this is only the focal loss (``criterion1``). ``criterion2`` and
    ``criterion3`` are constructed above but not used by this function; swap
    them in here if a different or combined loss is wanted.
    """
    return criterion1(x, y)
# ---- datasets / dataloaders ------------------------------------------
print("preparing datasets and dataloaders......")
batch_size = args.batch_size
# Meters labelled "DataTime" and "BatchTime", both formatted with ':6.3f'.
data_time = AverageMeter("DataTime", ':6.3f')
batch_time = AverageMeter("BatchTime", ':6.3f')

dataset_train = get_dataset(data_path, meta_path, train=True)
dataloader_train = DataLoader(
    dataset_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=collate,
    pin_memory=True,
)

dataset_val = get_dataset(data_path, meta_path, train=False)
dataloader_val = DataLoader(
    dataset_val,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    collate_fn=collate,
    pin_memory=True,
)

# ---- model ------------------------------------------------------------
print("creating models......")
# Join onto model_path only when a checkpoint name was supplied; otherwise
# the (falsy) args.path_test value is passed through unchanged.
if args.path_test:
    path = os.path.join(model_path, args.path_test)
else:
    path = args.path_test
model = create_model_load_weights(n_class, evaluation, path=path)

# ---- optimisation -----------------------------------------------------
num_epochs = args.epoch
learning_rate = args.lr
momentum = args.momentum
weight_decay = args.weight_decay
opt_args = {"lr": learning_rate, "momentum": momentum, "weight_decay": weight_decay}
optimizer = get_optimizer(model, **opt_args)
scheduler = LR_Scheduler('poly', learning_rate, num_epochs, len(dataloader_train))

criterion = BCELoss()

# The summary writer is only created when not in evaluation mode.
# NOTE(review): log_dir is plain string concatenation, so log_path must end
# with a path separator (or be meant as a prefix) — confirm.
if not evaluation:
    writer = SummaryWriter(log_dir=log_path + task_name)