def main(config):
    """Train one fold of the segmentation model and write submission files.

    The function resolves the input paths, splits the training PNGs into
    folds, trains ``FPNSegmentation`` with Adam, validates after every epoch
    over a range of thresholds, keeps the best checkpoints in
    ``config.logdir`` and finally writes a test and a validation submission.

    Args:
        config: run configuration. Attributes read here include ``logdir``,
            ``train_dir``, ``test_dir``, ``train_rle``, ``stratify``,
            ``num_folds``, ``fold``, ``drop_empty``, ``height``,
            ``is_kernel``, ``batch_size``, ``num_workers``, ``pin``,
            ``slug``, ``weight``, ``device``, ``lr``, ``weight_decay``,
            ``apex``, ``epochs``, ``warmup``, ``accumulation_step``,
            ``n_keep``, ``pred_zip`` and ``submission_fn``. Several path
            attributes are overwritten when ``train_dir`` does not exist.
    """
    os.makedirs('cache', exist_ok=True)
    os.makedirs(config.logdir, exist_ok=True)
    print("Logging to: %s" % config.logdir)

    # Fall back to the Kaggle kernel layout when the local data is missing.
    if not os.path.exists(config.train_dir):
        print("KERNEL ENV")
        config.train_dicom_dir = '../input/siim-train-test/siim/dicom-images-train'
        config.test_dicom_dir = '../input/siim-train-test/siim/dicom-images-test'
        config.train_dir = '../input/l2-images/l2-images/l2-images-train'
        config.test_dir = '../input/l2-images/l2-images/l2-images-test'
        config.sample_submission = '../input/siim-acr-pneumothorax-segmentation/' \
            'sample_submission.csv'
        config.train_rle = '../input/siim-train-test/siim/train-rle.csv'

    train_image_fns = sorted(glob(os.path.join(config.train_dir, '*.png')))
    test_image_fns = sorted(glob(os.path.join(config.test_dir, '*.png')))
    assert len(train_image_fns) == 10675, len(train_image_fns)
    assert len(test_image_fns) in (1372, 1377), len(test_image_fns)

    gt = load_gt(config.train_rle)

    # create folds
    if not config.stratify:
        # random folds
        np.random.shuffle(train_image_fns)
    else:
        # folds stratified by mask size: sort by mask area so that the
        # round-robin fold assignment below spreads sizes evenly.
        train_mask_sizes = [
            L2DicomDataset.rles_to_mask(gt[L2DicomDataset.fn_to_id(fn)]).sum()
            for fn in tqdm(train_image_fns)
        ]
        sorted_inds = sorted(range(len(train_image_fns)),
                             key=lambda k: train_mask_sizes[k])
        train_image_fns = [train_image_fns[k] for k in sorted_inds]

    folds = np.arange(len(train_image_fns)) % config.num_folds
    val_image_fns = [
        fn for k, fn in enumerate(train_image_fns) if folds[k] == config.fold
    ]
    train_image_fns = [
        fn for k, fn in enumerate(train_image_fns) if folds[k] != config.fold
    ]

    # remove not-used files:
    # https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/discussion/98478#latest-572385  # noqa
    train_image_fns = [
        fn for fn in train_image_fns if L2DicomDataset.fn_to_id(fn) in gt
    ]
    val_image_fns = [
        fn for fn in val_image_fns if L2DicomDataset.fn_to_id(fn) in gt
    ]

    if config.drop_empty:
        # remove empty masks from training data
        non_empty_gt = {k: v for k, v in gt.items() if v[0] != ' -1'}
        train_image_fns = [
            fn for fn in train_image_fns
            if L2DicomDataset.fn_to_id(fn) in non_empty_gt
        ]
        print("[Non-EMPTY] TRAIN: ", len(train_image_fns),
              os.path.basename(train_image_fns[0]))

    print("VAL: ", len(val_image_fns), os.path.basename(val_image_fns[0]))
    print("TRAIN: ", len(train_image_fns), os.path.basename(train_image_fns[0]))

    train_ds = L2DicomDataset(train_image_fns, gt_rles=gt,
                              height=config.height, width=config.height,
                              to_ram=True, augment=True,
                              write_cache=not config.is_kernel,
                              train_dicom_dir=config.train_dicom_dir,
                              test_dicom_dir=config.test_dicom_dir)
    val_ds = L2DicomDataset(val_image_fns, gt_rles=gt,
                            height=config.height, width=config.height,
                            to_ram=True,
                            write_cache=not config.is_kernel,
                            train_dicom_dir=config.train_dicom_dir,
                            test_dicom_dir=config.test_dicom_dir)
    val_loader = data.DataLoader(val_ds, batch_size=config.batch_size,
                                 shuffle=False,
                                 num_workers=config.num_workers,
                                 pin_memory=config.pin, drop_last=False)

    model = FPNSegmentation(config.slug, num_input_channels=2)
    if config.weight is not None:
        model.load_state_dict(th.load(config.weight))
    model = model.to(config.device)

    optimizer = th.optim.Adam(model.parameters(), lr=config.lr,
                              weight_decay=config.weight_decay)
    if config.apex:
        model, optimizer = apex.amp.initialize(model, optimizer,
                                               opt_level="O1", verbosity=0)

    # The schedule is advanced once per optimizer step (`global_step`), and
    # with gradient accumulation there is only one optimizer step every
    # `accumulation_step` batches, so count optimizer steps, not batches.
    batches_per_epoch = len(train_ds) // config.batch_size
    updates_per_epoch = batches_per_epoch // config.accumulation_step
    num_updates = int(config.epochs * updates_per_epoch)
    scheduler = WarmupLinearSchedule(warmup=config.warmup, t_total=num_updates)

    # training loop
    smooth = 0.1  # weight of the newest value in the running averages
    best_dice = 0.0
    best_fn = None
    global_step = 0
    for epoch in range(config.epochs):
        smooth_loss = None
        smooth_accuracy = None
        model.train()
        train_loader = data.DataLoader(train_ds, batch_size=config.batch_size,
                                       shuffle=True,
                                       num_workers=config.num_workers,
                                       pin_memory=config.pin, drop_last=True)
        progress = tqdm(total=len(train_ds), smoothing=0.01)
        for i, (X, y_true) in enumerate(train_loader):
            X = X.to(config.device)
            y_true = y_true.to(config.device)
            y_pred = model(X)
            loss = siim_loss(y_true, y_pred, weights=None)
            if config.apex:
                with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            lr_this_step = None
            if (i + 1) % config.accumulation_step == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_this_step = config.lr * scheduler.get_lr(
                    global_step, config.warmup)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                global_step += 1

            smooth_loss = loss.item() if smooth_loss is None else \
                smooth * loss.item() + (1. - smooth) * smooth_loss
            accuracy = th.mean(
                ((y_pred >= 0.5) == (y_true >= 0.5)).to(th.float)).item()
            smooth_accuracy = accuracy if smooth_accuracy is None else \
                smooth * accuracy + (1. - smooth) * smooth_accuracy
            progress.set_postfix(
                loss='%.4f' % smooth_loss,
                accuracy='%.4f' % (smooth_accuracy),
                lr='%.6f' % (config.lr if lr_this_step is None
                             else lr_this_step))
            progress.update(len(X))
        # The bar is driven manually and, because of drop_last, may never
        # reach its total; close it explicitly so bars do not pile up.
        progress.close()

        # validation loop
        model.eval()
        thresholds = np.arange(0.1, 0.7, 0.1)
        dice_coeffs = [[] for _ in range(len(thresholds))]
        progress = tqdm(enumerate(val_loader), total=len(val_loader))
        with th.no_grad():
            for i, (X, y_trues) in progress:
                X = X.to(config.device)
                y_trues = y_trues.to(config.device)
                y_preds = model(X)
                for yt, yp in zip(y_trues, y_preds):
                    yt = (yt.squeeze().cpu().numpy() >= 0.5).astype('uint8')
                    yp = yp.squeeze().cpu().numpy()
                    for dind, threshold in enumerate(thresholds):
                        yp_ = (yp >= threshold).astype(np.uint8)
                        # `score` results are used below as
                        # (dice value, 'empty' | 'non-empty') pairs.
                        sc = score(yt, yp_)
                        dice_coeffs[dind].append(sc)

        # Pick the threshold with the best mean dice on non-empty masks.
        best_threshold_ind = -1
        dice_coeff = -1
        for dind in range(len(thresholds)):
            dc = np.mean(
                [x[0] for x in dice_coeffs[dind] if x[1] == 'non-empty'])
            if dc > dice_coeff:
                dice_coeff = dc
                best_threshold_ind = dind

        dice_coeffs = dice_coeffs[best_threshold_ind]
        num_empty = sum(1 for x in dice_coeffs if x[1] == 'empty')
        num_total = len(dice_coeffs)
        num_non_empty = num_total - num_empty
        empty_sum = np.sum([d[0] for d in dice_coeffs if d[1] == 'empty'])
        non_empty_sum = np.sum(
            [d[0] for d in dice_coeffs if d[1] == 'non-empty'])
        dice_coeff_empty = empty_sum / num_empty
        dice_coeff_non_empty = non_empty_sum / num_non_empty
        progress.write(
            '[Empty: %d]: %.3f | %.3f, [Non-Empty: %d]: %.3f | %.3f' %
            (num_empty, dice_coeff_empty, empty_sum / num_total,
             num_non_empty, dice_coeff_non_empty, non_empty_sum / num_total))

        dice_coeff = float(dice_coeff)
        summary_str = 'f%02d-ep-%04d-val_dice-%.4f@%.2f' % (
            config.fold, epoch, dice_coeff, thresholds[best_threshold_ind])
        progress.write(summary_str)

        if dice_coeff > best_dice:
            weight_fn = os.path.join(config.logdir, summary_str + '.pth')
            th.save(model.state_dict(), weight_fn)
            best_dice = dice_coeff
            best_fn = weight_fn
            # Keep only the newest `n_keep` checkpoints of this fold.
            fns = sorted(
                glob(os.path.join(config.logdir, 'f%02d-*.pth' % config.fold)))
            for fn in fns[:-config.n_keep]:
                os.remove(fn)

    # create submission
    test_ds = L2DicomDataset(test_image_fns,
                             height=config.height, width=config.height,
                             write_cache=not config.is_kernel,
                             train_dicom_dir=config.train_dicom_dir,
                             test_dicom_dir=config.test_dicom_dir)
    test_loader = data.DataLoader(test_ds, batch_size=config.batch_size,
                                  shuffle=False, num_workers=0,
                                  pin_memory=False, drop_last=False)
    if best_fn is not None:
        model.load_state_dict(th.load(best_fn))
    model.eval()
    sub = create_submission(model, test_loader, test_image_fns, config,
                            pred_zip=config.pred_zip)
    sub.to_csv(config.submission_fn, index=False)
    print("Wrote to: %s" % config.submission_fn)

    # create val submission
    val_fn = config.submission_fn.replace('.csv', '_VAL.csv')
    model.eval()
    sub = create_submission(model, val_loader, val_image_fns, config,
                            pred_zip=config.pred_zip.replace(
                                '.zip', '_VAL.zip'))
    sub.to_csv(val_fn, index=False)
    print("Wrote to: %s" % val_fn)
def main(config):
    """Train one fold of the DICOM segmentation model and write submissions.

    The function seeds the RNGs, snapshots the ``*.py`` sources into
    ``config.logdir``, splits the training DICOM files into folds, trains
    ``FPNSegmentation`` with AdamW, validates with horizontal-flip averaging
    once ``epoch > 12``, keeps the best checkpoints and finally writes a test
    and a validation submission.

    Args:
        config: run configuration. Attributes read here include ``logdir``,
            ``train_dir``, ``test_dir``, ``train_rle``, ``subset``,
            ``num_folds``, ``fold``, ``cache``, ``batch_size``,
            ``num_workers``, ``pin``, ``slug``, ``ema``, ``weight``,
            ``device``, ``weight_decay``, ``lr``, ``apex``, ``epochs``,
            ``warmup``, ``accumulation_step``, ``n_keep``, ``pred_zip`` and
            ``submission_fn``.
    """
    seed_all()
    os.makedirs('cache', exist_ok=True)
    os.makedirs(config.logdir, exist_ok=True)
    print("Logging to: %s" % config.logdir)

    # Snapshot the scripts of the working directory next to the logs.
    for src_fn in sorted(glob('*.py')):
        copyfile(src_fn, os.path.join(config.logdir, src_fn))

    train_image_fns = sorted(glob(os.path.join(config.train_dir, '*/*/*.dcm')))
    test_image_fns = sorted(glob(os.path.join(config.test_dir, '*/*/*.dcm')))
    # assert len(train_image_fns) == 10712
    # assert len(test_image_fns) == 1377

    gt = load_gt(config.train_rle)

    # create folds
    np.random.shuffle(train_image_fns)
    if config.subset > 0:
        train_image_fns = train_image_fns[:config.subset]

    folds = np.arange(len(train_image_fns)) % config.num_folds
    val_image_fns = [fn for k, fn in enumerate(train_image_fns)
                     if folds[k] == config.fold]
    train_image_fns = [fn for k, fn in enumerate(train_image_fns)
                       if folds[k] != config.fold]

    # remove not-used files:
    # https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/discussion/98478#latest-572385  # noqa
    train_image_fns = [fn for fn in train_image_fns
                       if DicomDataset.fn_to_id(fn) in gt]
    val_image_fns = [fn for fn in val_image_fns
                     if DicomDataset.fn_to_id(fn) in gt]

    print("VAL: ", len(val_image_fns), os.path.basename(val_image_fns[0]))
    print("TRAIN: ", len(train_image_fns), os.path.basename(train_image_fns[0]))

    train_ds = DicomDataset(train_image_fns, gt_rles=gt, augment=True)
    val_ds = DicomDataset(val_image_fns, gt_rles=gt)
    if config.cache:
        train_ds.cache()
        val_ds.cache()

    val_loader = data.DataLoader(val_ds, batch_size=config.batch_size,
                                 shuffle=False,
                                 num_workers=config.num_workers,
                                 pin_memory=config.pin, drop_last=False)

    model = FPNSegmentation(config.slug, ema=config.ema)
    if config.weight is not None:
        print("Loading: %s" % config.weight)
        model.load_state_dict(th.load(config.weight))
    model = model.to(config.device)

    # Parameters whose name contains one of these substrings get no weight
    # decay; everything else uses `config.weight_decay`.
    no_decay = ['mean', 'std', 'bias'] + ['.bn%d.' % i for i in range(100)]
    grouped_parameters = [{'params': [], 'weight_decay': config.weight_decay},
                          {'params': [], 'weight_decay': 0.0}]
    for n, p in model.named_parameters():
        if not any(nd in n for nd in no_decay):
            print("Decay: %s" % n)
            grouped_parameters[0]['params'].append(p)
        else:
            print("No Decay: %s" % n)
            grouped_parameters[1]['params'].append(p)

    optimizer = AdamW(grouped_parameters, lr=config.lr)
    if config.apex:
        model, optimizer = apex.amp.initialize(model, optimizer,
                                               opt_level="O1", verbosity=0)

    # The schedule is advanced once per optimizer step (`global_step`), and
    # with gradient accumulation there is only one optimizer step every
    # `accumulation_step` batches, so count optimizer steps, not batches.
    batches_per_epoch = len(train_ds) // config.batch_size
    updates_per_epoch = batches_per_epoch // config.accumulation_step
    num_updates = int(config.epochs * updates_per_epoch)
    scheduler = WarmupLinearSchedule(warmup=config.warmup, t_total=num_updates)

    # training loop
    smooth = 0.1  # weight of the newest value in the running averages
    best_dice = 0.0
    best_fn = None
    global_step = 0
    for epoch in range(1, config.epochs + 1):
        smooth_loss = None
        smooth_accuracy = None
        model.train()
        train_loader = data.DataLoader(train_ds, batch_size=config.batch_size,
                                       shuffle=True,
                                       num_workers=config.num_workers,
                                       pin_memory=config.pin, drop_last=True)
        progress = tqdm(total=len(train_ds), smoothing=0.01)
        for i, (X, _, y_true) in enumerate(train_loader):
            X = X.to(config.device).float()
            y_true = y_true.to(config.device)
            y_pred = model(X)
            loss = siim_loss(y_true, y_pred, weights=None)
            if config.apex:
                with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            lr_this_step = None
            if (i + 1) % config.accumulation_step == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_this_step = config.lr * scheduler.get_lr(global_step,
                                                            config.warmup)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                global_step += 1

            smooth_loss = loss.item() if smooth_loss is None else \
                smooth * loss.item() + (1. - smooth) * smooth_loss
            accuracy = th.mean(((y_pred >= 0.5) == (y_true == 1)).to(
                th.float)).item()
            smooth_accuracy = accuracy if smooth_accuracy is None else \
                smooth * accuracy + (1. - smooth) * smooth_accuracy
            progress.set_postfix(
                ep='%d/%d' % (epoch, config.epochs),
                loss='%.4f' % smooth_loss,
                accuracy='%.4f' % (smooth_accuracy),
                lr='%.6f' % (config.lr if lr_this_step is None
                             else lr_this_step))
            progress.update(len(X))
        # The bar is driven manually and, because of drop_last, may never
        # reach its total; close it explicitly so bars do not pile up.
        progress.close()

        # Validation (and therefore checkpointing) is skipped for the
        # first 12 epochs.
        if epoch <= 12:
            continue

        # validation loop
        model.eval()
        thresholds = [0.1, 0.2]
        dice_coeffs = [[] for _ in range(len(thresholds))]
        progress = tqdm(enumerate(val_loader), total=len(val_loader))
        with th.no_grad():
            for i, (X, _, y_trues) in progress:
                X = X.to(config.device).float()
                y_trues = y_trues.to(config.device)
                # Average the prediction with the un-flipped prediction of
                # the horizontally flipped input.
                y_preds = model(X)
                y_preds_flip = th.flip(model(th.flip(X, (-1, ))), (-1, ))
                y_preds = 0.5 * (y_preds + y_preds_flip)
                y_trues = y_trues.cpu().numpy()
                y_preds = y_preds.cpu().numpy()
                for yt, yp in zip(y_trues, y_preds):
                    yt = (yt.squeeze() >= 0.5).astype('uint8')
                    yp = yp.squeeze()
                    for dind, threshold in enumerate(thresholds):
                        yp_ = (yp >= threshold).astype(np.uint8)
                        # `score` results are used below as
                        # (dice value, 'empty' | 'non-empty') pairs.
                        sc = score(yt, yp_)
                        dice_coeffs[dind].append(sc)

        # Pick the threshold with the best mean dice on non-empty masks.
        best_threshold_ind = -1
        dice_coeff = -1
        for dind in range(len(thresholds)):
            dc = np.mean([x[0] for x in dice_coeffs[dind]
                          if x[1] == 'non-empty'])
            if dc > dice_coeff:
                dice_coeff = dc
                best_threshold_ind = dind

        dice_coeffs = dice_coeffs[best_threshold_ind]
        num_empty = sum(1 for x in dice_coeffs if x[1] == 'empty')
        num_total = len(dice_coeffs)
        num_non_empty = num_total - num_empty
        empty_sum = np.sum([d[0] for d in dice_coeffs if d[1] == 'empty'])
        non_empty_sum = np.sum([d[0] for d in dice_coeffs
                                if d[1] == 'non-empty'])
        dice_coeff_empty = empty_sum / num_empty
        dice_coeff_non_empty = non_empty_sum / num_non_empty
        progress.write(
            '[Empty: %d]: %.3f | %.3f, [Non-Empty: %d]: %.3f | %.3f' % (
                num_empty, dice_coeff_empty, empty_sum / num_total,
                num_non_empty, dice_coeff_non_empty,
                non_empty_sum / num_total))

        dice_coeff = float(dice_coeff)
        summary_str = 'f%02d-ep-%04d-val_dice-%.4f@%.2f' % (
            config.fold, epoch, dice_coeff, thresholds[best_threshold_ind])
        progress.write(summary_str)

        if dice_coeff > best_dice:
            weight_fn = os.path.join(config.logdir, summary_str + '.pth')
            th.save(model.state_dict(), weight_fn)
            best_dice = dice_coeff
            best_fn = weight_fn
            # Keep only the newest `n_keep` checkpoints of this fold.
            fns = sorted(glob(os.path.join(config.logdir,
                                           'f%02d-*.pth' % config.fold)))
            for fn in fns[:-config.n_keep]:
                os.remove(fn)

    # create submission
    test_ds = DicomDataset(test_image_fns)
    test_loader = data.DataLoader(test_ds, batch_size=config.batch_size,
                                  shuffle=False, num_workers=0,
                                  pin_memory=False, drop_last=False)
    if best_fn is not None:
        model.load_state_dict(th.load(best_fn))
    model.eval()
    sub = create_submission(model, test_loader, config,
                            pred_zip=config.pred_zip)
    sub.to_csv(config.submission_fn, index=False)
    print("Wrote to: %s" % config.submission_fn)

    # create val submission
    val_fn = config.submission_fn.replace('.csv', '_VAL.csv')
    model.eval()
    sub = create_submission(model, val_loader, config,
                            pred_zip=config.pred_zip.replace('.zip',
                                                             '_VAL.zip'))
    sub.to_csv(val_fn, index=False)
    print("Wrote to: %s" % val_fn)
def main(config):
    """Train one fold of the Kuzushiji model and write submission files.

    The function seeds the RNGs, snapshots the ``*.py`` sources into
    ``config.logdir``, splits the training JPEGs into folds, trains
    ``FPNSegmentation`` with AdamW on the heat-map / class targets, validates
    on selected epochs, keeps the best checkpoints and finally writes a test
    and a validation submission.

    Args:
        config: run configuration. Attributes read here include ``logdir``,
            ``train_dir``, ``test_dir``, ``train_rle``, ``subset``,
            ``num_folds``, ``fold``, ``add_val``, ``cache``, ``batch_size``,
            ``num_workers``, ``pin``, ``slug``, ``weight``, ``device``,
            ``weight_decay``, ``lr``, ``apex``, ``epochs``, ``warmup``,
            ``accumulation_step``, ``n_keep``, ``pred_zip`` and
            ``submission_fn``.
    """
    seed_all()
    os.makedirs('cache', exist_ok=True)
    os.makedirs(config.logdir, exist_ok=True)
    print("Logging to: %s" % config.logdir)

    # Snapshot the scripts of the working directory next to the logs.
    for src_fn in sorted(glob('*.py')):
        copyfile(src_fn, os.path.join(config.logdir, src_fn))

    train_image_fns = sorted(glob(os.path.join(config.train_dir, '*.jpg')))
    test_image_fns = sorted(glob(os.path.join(config.test_dir, '*.jpg')))
    assert len(train_image_fns) == 3881
    assert len(test_image_fns) == 4150

    gt, label_to_int = load_gt(config.train_rle)
    int_to_label = {v: k for k, v in label_to_int.items()}

    # create folds
    np.random.shuffle(train_image_fns)
    if config.subset > 0:
        train_image_fns = train_image_fns[:config.subset]

    folds = np.arange(len(train_image_fns)) % config.num_folds
    val_image_fns = [
        fn for k, fn in enumerate(train_image_fns) if folds[k] == config.fold
    ]
    train_image_fns = [
        fn for k, fn in enumerate(train_image_fns) if folds[k] != config.fold
    ]

    if config.add_val:
        print("Training on validation set")
        train_image_fns = train_image_fns + val_image_fns[:]
        print(len(val_image_fns), len(train_image_fns))

    # TODO: drop empty images <- is this helpful?
    train_image_fns = [
        fn for fn in train_image_fns if KuzushijiDataset.fn_to_id(fn) in gt
    ]
    val_image_fns = [
        fn for fn in val_image_fns if KuzushijiDataset.fn_to_id(fn) in gt
    ]

    # Show one sample path per split; clamp the index so that small
    # debugging subsets do not raise IndexError.
    if val_image_fns:
        print("VAL: ", len(val_image_fns),
              val_image_fns[min(123, len(val_image_fns) - 1)])
    if train_image_fns:
        print("TRAIN: ", len(train_image_fns),
              train_image_fns[min(456, len(train_image_fns) - 1)])

    train_ds = KuzushijiDataset(train_image_fns, gt_boxes=gt,
                                label_to_int=label_to_int, augment=True)
    val_ds = KuzushijiDataset(val_image_fns, gt_boxes=gt,
                              label_to_int=label_to_int)
    if config.cache:
        train_ds.cache()
        val_ds.cache()

    # Evaluation uses 1/8 of the training batch size; never let it drop to
    # 0, which DataLoader rejects.
    eval_batch_size = max(1, config.batch_size // 8)
    val_loader = data.DataLoader(val_ds, batch_size=eval_batch_size,
                                 shuffle=False,
                                 num_workers=config.num_workers,
                                 pin_memory=config.pin, drop_last=False)

    model = FPNSegmentation(config.slug)
    if config.weight is not None:
        print("Loading: %s" % config.weight)
        model.load_state_dict(th.load(config.weight))
    model = model.to(config.device)

    # Parameters whose name contains one of these substrings get no weight
    # decay; everything else uses `config.weight_decay`.
    no_decay = ['mean', 'std', 'bias'] + ['.bn%d.' % i for i in range(100)]
    grouped_parameters = [{
        'params': [],
        'weight_decay': config.weight_decay
    }, {
        'params': [],
        'weight_decay': 0.0
    }]
    for n, p in model.named_parameters():
        if not any(nd in n for nd in no_decay):
            grouped_parameters[0]['params'].append(p)
        else:
            grouped_parameters[1]['params'].append(p)

    optimizer = AdamW(grouped_parameters, lr=config.lr)
    if config.apex:
        model, optimizer = apex.amp.initialize(model, optimizer,
                                               opt_level="O1", verbosity=0)

    # The schedule is advanced once per optimizer step (`global_step`), and
    # with gradient accumulation there is only one optimizer step every
    # `accumulation_step` batches, so count optimizer steps, not batches.
    batches_per_epoch = len(train_ds) // config.batch_size
    updates_per_epoch = batches_per_epoch // config.accumulation_step
    num_updates = int(config.epochs * updates_per_epoch)
    scheduler = WarmupLinearSchedule(warmup=config.warmup, t_total=num_updates)

    # training loop
    smooth = 0.1  # weight of the newest value in the running averages
    best_acc = 0.0
    best_fn = None
    global_step = 0
    for epoch in range(1, config.epochs + 1):
        smooth_loss = None
        smooth_accuracy = None
        model.train()
        train_loader = data.DataLoader(train_ds, batch_size=config.batch_size,
                                       shuffle=True,
                                       num_workers=config.num_workers,
                                       pin_memory=config.pin, drop_last=True)
        progress = tqdm(total=len(train_ds), smoothing=0.01)
        for i, (X, fns, hm, centers, classes) in enumerate(train_loader):
            X = X.to(config.device).float()
            hm = hm.to(config.device)
            centers = centers.to(config.device)
            classes = classes.to(config.device)
            hm_pred, classes_pred = model(X, centers=centers)
            loss = kuzushiji_loss(hm, centers, classes, hm_pred, classes_pred)
            if config.apex:
                with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            lr_this_step = None
            if (i + 1) % config.accumulation_step == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_this_step = config.lr * scheduler.get_lr(
                    global_step, config.warmup)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                global_step += 1

            smooth_loss = loss.item() if smooth_loss is None else \
                smooth * loss.item() + (1. - smooth) * smooth_loss
            accuracy = th.mean(
                ((th.sigmoid(hm_pred) >= 0.5) == (hm == 1)).to(
                    th.float)).item()
            smooth_accuracy = accuracy if smooth_accuracy is None else \
                smooth * accuracy + (1. - smooth) * smooth_accuracy
            progress.set_postfix(
                ep='%d/%d' % (epoch, config.epochs),
                loss='%.4f' % smooth_loss,
                accuracy='%.4f' % (smooth_accuracy),
                lr='%.6f' % (config.lr if lr_this_step is None
                             else lr_this_step))
            progress.update(len(X))
        # The bar is driven manually and, because of drop_last, may never
        # reach its total; close it explicitly so bars do not pile up.
        progress.close()

        # skip validation: only epochs 1, 10, 20, 30, 40, 50 and those
        # after 65 are validated.
        if epoch not in [10, 20, 30, 40, 50]:
            if 1 < epoch <= 65:
                continue

        # validation loop
        model.eval()
        progress = tqdm(enumerate(val_loader), total=len(val_loader))
        hm_correct, classes_correct = 0, 0
        num_hm, num_classes = 0, 0
        with th.no_grad():
            for i, (X, fns, hm, centers, classes) in progress:
                # Use the configured device (as the training loop does)
                # rather than a hard-coded `.cuda()`.
                X = X.to(config.device).float()
                hm = hm.to(config.device)
                centers = centers.to(config.device)
                classes = classes.to(config.device)
                hm_pred, classes_pred = model(X)
                hm_pred = th.sigmoid(hm_pred)
                classes_pred = th.nn.functional.softmax(classes_pred, 1)

                hm_target = hm
                # PyTorch 1.2 has `bool`
                if hasattr(hm_target, 'bool'):
                    hm_target = hm_target.bool()
                hm_correct += (
                    hm_target == (hm_pred >= 0.5)).float().sum().item()
                num_hm += hm.numel()

                for sample_ind in range(len(X)):
                    # Rows whose first coordinate is -1 are skipped.
                    center_mask = centers[sample_ind, :, 0] != -1
                    per_image_letters = center_mask.sum().item()
                    if per_image_letters == 0:
                        continue
                    centers_per_img = centers[sample_ind][center_mask]
                    classes_per_img = classes[sample_ind][center_mask]
                    # Read the class scores at each centre (indexed as
                    # [:, centers[:, 1], centers[:, 0]]) and take the argmax.
                    classes_per_img_pred = classes_pred[sample_ind][
                        :, centers_per_img[:, 1],
                        centers_per_img[:, 0]].argmax(0)
                    classes_correct += (
                        classes_per_img_pred == classes_per_img).sum().item()
                    # Count every letter exactly once so the accuracy is
                    # correct / total.
                    num_classes += per_image_letters

        val_hm_acc = hm_correct / num_hm if num_hm else 0.0
        val_classes_acc = classes_correct / num_classes if num_classes else 0.0
        summary_str = 'f%02d-ep-%04d-val_hm_acc-%.4f-val_classes_acc-%.4f' % (
            config.fold, epoch, val_hm_acc, val_classes_acc)
        progress.write(summary_str)

        if val_classes_acc >= best_acc:
            weight_fn = os.path.join(config.logdir, summary_str + '.pth')
            progress.write("New best: %s" % weight_fn)
            th.save(model.state_dict(), weight_fn)
            best_acc = val_classes_acc
            best_fn = weight_fn
            # Keep only the newest `n_keep` checkpoints of this fold.
            fns = sorted(
                glob(os.path.join(config.logdir, 'f%02d-*.pth' % config.fold)))
            for fn in fns[:-config.n_keep]:
                os.remove(fn)

    # create submission
    test_ds = KuzushijiDataset(test_image_fns)
    test_loader = data.DataLoader(test_ds, batch_size=eval_batch_size,
                                  shuffle=False,
                                  num_workers=config.num_workers,
                                  pin_memory=False, drop_last=False)
    if best_fn is not None:
        model.load_state_dict(th.load(best_fn))
    model.eval()
    sub = create_submission(model, test_loader, int_to_label, config,
                            pred_zip=config.pred_zip)
    sub.to_csv(config.submission_fn, index=False)
    print("Wrote to: %s" % config.submission_fn)

    # create val submission
    val_fn = config.submission_fn.replace('.csv', '_VAL.csv')
    model.eval()
    sub = create_submission(model, val_loader, int_to_label, config,
                            pred_zip=config.pred_zip.replace(
                                '.zip', '_VAL.zip'))
    sub.to_csv(val_fn, index=False)
    print("Wrote to: %s" % val_fn)
def __init__(self, training_dataloader, validate_dataloaer, optimizer,
             loss_func, model, num_epoch=100, lr=0.0002, gpus=None,
             pretrained_path=None, checkpoint_save_path="best_model.pt",
             is_apex=False, is_scheduler=True):
    """Store the training components and prepare the model and devices.

    :param training_dataloader: loader for the training data; its length
        times ``num_epoch`` is used as the schedule's total step count.
    :param validate_dataloaer: loader for the validation data.
    :param optimizer: optimizer instance, stored as given.
    :param loss_func: loss callable, stored as given.
    :param model: model to train; a pre-trained state dict is loaded into
        it when ``pretrained_path`` is given.
    :param num_epoch: number of training epochs.
    :param lr: learning rate, stored on the instance.
    :param gpus: value written to ``CUDA_VISIBLE_DEVICES`` and passed to
        ``setting_cuda``; must be a string when set. ``None`` leaves the
        environment untouched (how ``setting_cuda`` treats ``None`` is not
        visible here -- TODO confirm).
    :param pretrained_path: optional path of a state dict for ``model``.
    :param checkpoint_save_path: path stored as ``self.checkpoint_path``.
    :param is_apex: stored flag; apex initialisation is currently disabled.
    :param is_scheduler: create a ``WarmupLinearSchedule`` when true.
    """
    # init data
    self.training_data = training_dataloader
    self.validate_data = validate_dataloaer
    self.optimizer = optimizer
    self.loss_func = loss_func
    self.model = model
    self.checkpoint_path = checkpoint_save_path

    # support vars
    self.name_model = model.__class__.__name__
    self.writer = SummaryWriter()
    self.best_current_loss = 1000000000
    self.current_delay_overfit = 0
    self.is_apex = is_apex
    self.lr = lr
    self.num_epoch = num_epoch
    self.scheduler = WarmupLinearSchedule(
        warmup=0.03,
        t_total=len(training_dataloader) * num_epoch) if is_scheduler else None

    # Environment values must be strings, so only restrict the visible
    # devices when a value was actually supplied (the default is None).
    if gpus is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpus

    # load pre-trained model
    if pretrained_path is not None:
        print("loaded model %s" % pretrained_path)
        # NOTE: the checkpoint is loaded strictly; a previous variant that
        # copied only the matching keys of checkpoint['state_dict'] is
        # no longer used.
        self.model.load_state_dict(torch.load(pretrained_path))

    # setting cuda if needed
    self.gpus, self.model = setting_cuda(gpus, self.model)

    if self.is_apex:
        # apex.amp.initialize(...) is intentionally disabled for now.
        pass

    self.is_cuda = len(self.gpus) >= 1