Example no. 1
def train(args):
    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)
    if args.config_file != "":
        shutil.copy(args.config_file, cfg.OUTPUT_DIR)

    num_gpus = torch.cuda.device_count()

    logger = setup_logger('reid_baseline', output_dir, 0)
    logger.info('Using {} GPUS'.format(num_gpus))
    logger.info(args)
    logger.info('Running with config:\n{}'.format(cfg))

    train_dl, val_dl, num_query, num_classes = make_dataloader(cfg, num_gpus)

    model = build_model(cfg, num_classes)
    # print(model)
    loss_func = make_loss(cfg, num_classes)

    trainer = BaseTrainer(cfg, model, train_dl, val_dl, loss_func, num_query,
                          num_gpus)

    for epoch in range(trainer.epochs):
        for batch in trainer.train_dl:
            trainer.step(batch)
            trainer.handle_new_batch()
        trainer.handle_new_epoch()
Example no. 2
def main():
    test_ids = np.array([x[:-4] for x in os.listdir(args.test_folder) if x[-4:] == '.png'])

    MODEL_PATH = os.path.join(args.models_dir, args.network + args.alias)
    folds = [int(f) for f in args.fold.split(',')]

    print('Predicting Model:', args.network + args.alias)

    for fold in folds:
        K.clear_session()
        print('***************************** FOLD {} *****************************'.format(fold))

        # Initialize Model
        weights_path = os.path.join(MODEL_PATH, args.prediction_weights.format(fold))

        model, preprocess = get_model(args.network,
                                      input_shape=(args.input_size, args.input_size, 3),
                                      freeze_encoder=args.freeze_encoder)
        model.compile(optimizer=RMSprop(lr=args.learning_rate), loss=make_loss(args.loss_function),
                      metrics=[Kaggle_IoU_Precision])

        model.load_weights(weights_path)

        # Save test predictions to disk
        dir_path = os.path.join(MODEL_PATH, args.prediction_folder.format(fold))
        os.system("mkdir {}".format(dir_path))
        predict_test(model=model,
                     preds_path=dir_path,
                     ids=test_ids,
                     batch_size=args.batch_size * 2,
                     TTA='flip',
                     preprocess=preprocess)

        gc.collect()
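
The call above hands prediction off to a predict_test helper with flip test-time augmentation (TTA='flip'). That helper is not part of this listing; the sketch below only illustrates the general idea under stated assumptions (load_image and the .npy output format are placeholders, not the original code).

import os
import numpy as np

def predict_test_flip_tta(model, preds_path, ids, batch_size, preprocess, load_image):
    # Average predictions over the original and horizontally flipped inputs,
    # then write one prediction file per image id.
    for start in range(0, len(ids), batch_size):
        batch_ids = ids[start:start + batch_size]
        batch = np.stack([preprocess(load_image(i)) for i in batch_ids])
        preds = model.predict(batch, batch_size=len(batch_ids))
        preds_flip = model.predict(batch[:, :, ::-1, :], batch_size=len(batch_ids))
        preds = (preds + preds_flip[:, :, ::-1, :]) / 2.0
        for img_id, mask in zip(batch_ids, preds):
            np.save(os.path.join(preds_path, img_id + '.npy'), mask)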
Example no. 3
def unet_easy():
    # dataset
    dataset = NucleusDataset()
    dataset.load_nucleus(dataset_dir=train_data_dir)
    dataset.prepare()
    nucleus, masks = dataset.load_dataset()
    print(nucleus.shape, masks.shape)

    # model
    model = unet(input_size=(256, 256, 3), pre_weights=None, channels=1)
    # model.summary()
    # freeze_model(model, "input_1")
    optimizer = RMSprop(lr=0.001)
    best_model_file = '{}/best_{}.h5'.format("results", "unet")
    checkpointer = ModelCheckpoint(
        'trained_model_weight/model-dsbowl2018-1.h5',
        verbose=1,
        save_best_only=True)
    model.compile(optimizer=optimizer,
                  loss=make_loss('bce_dice'),
                  metrics=[binary_crossentropy, hard_dice_coef])
    results = model.fit(nucleus,
                        masks,
                        validation_split=0.1,
                        batch_size=16,
                        epochs=50,
                        callbacks=[checkpointer])
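
Every snippet in this listing obtains its loss from a make_loss factory. The factory itself is not shown here; as a rough, illustrative sketch (not the original implementation), a 'bce_dice' loss such as the one requested above could be built in Keras like this:

from keras import backend as K
from keras.losses import binary_crossentropy

def dice_loss(y_true, y_pred, smooth=1.0):
    # Soft dice loss on flattened masks.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return 1.0 - (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def bce_dice_loss(y_true, y_pred):
    # 'bce_dice': binary cross-entropy plus soft dice.
    return binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)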
Example no. 4
def main():
    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)
    num_gpus = torch.cuda.device_count()
    logger = setup_logger('reid_baseline', output_dir, 0)
    logger.info('Using {} GPUS'.format(num_gpus))
    logger.info('Running with config:\n{}'.format(cfg))
    train_dl, val_dl, num_query, num_classes = make_dataloader(cfg, num_gpus)
    model = build_model(cfg, num_classes)
    loss = make_loss(cfg, num_classes)
    trainer = BaseTrainer(cfg, model, train_dl, val_dl, loss, num_query,
                          num_gpus)
    for epoch in range(trainer.epochs):
        for batch in trainer.train_dl:
            trainer.step(batch)
            trainer.handle_new_batch()
        trainer.handle_new_epoch()
Example no. 5
def keras_fit_generator():

    kfolds = [0]
    # kfolds = [2, 3, 4]
    for fold in kfolds:

        K.clear_session()
        print('fold = {}'.format(fold))
        print('begin load data')
        X_train, y_train, X_val, y_val = load_data(fold)
        print(X_train.shape, y_train.shape, X_val.shape, y_val.shape)
        # (1250, 256, 256, 1) (1250, 256, 256, 1) (127, 256, 256, 1) (127, 256, 256, 1)
        print('load data over')


        model, process = get_model(network=network, input_shape=(X_train.shape[1], X_train.shape[2], X_train.shape[3]), freeze_encoder=False)

        model.load_weights(pretrain_weight + str(fold) + '.hdf5')

        val_gen = Generator(X_val, y_val, batch_size=len(y_val), shuffle=True, aug=True, process=process)
        X_val_steps, y_val_steps = next(val_gen.generator)

        train_gen = Generator(X_train, y_train, batch_size=batch_size, shuffle=True, aug=True, process=process)

        model.compile(optimizer=Adam(lr=learning_rate), loss=make_loss(loss_name=loss_function), metrics=[dice_coef])

        model.summary()

        c_backs = get_callback(callback, fold, num_sample=len(X_train))

        model.fit_generator(
                            train_gen.generator,
                            steps_per_epoch=(len(X_train)//batch_size)*2,
                            epochs=epochs,
                            verbose=1,
                            shuffle=True,
                            validation_data=(X_val_steps, y_val_steps),
                            callbacks=c_backs,
                            use_multiprocessing=False)
        gc.collect()
Example no. 6
def main():
    parser = argparse.ArgumentParser(description="ReID Baseline Training")
    parser.add_argument("--config_file",
                        default="",
                        help="path to config file",
                        type=str)
    parser.add_argument("opts",
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    args = parser.parse_args()
    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)
    num_gpus = torch.cuda.device_count()
    logger = setup_logger('reid_baseline', output_dir, 0)
    logger.info('Using {} GPUS'.format(num_gpus))
    logger.info('Running with config:\n{}'.format(cfg))
    train_dl, val_dl, num_query, num_classes = make_dataloader(cfg, num_gpus)
    model = build_model(cfg, num_classes)
    loss = make_loss(cfg, num_classes)
    trainer = SGDTrainer(cfg, model, train_dl, val_dl, loss, num_query,
                         num_gpus)
    logger.info('train transform: \n{}'.format(train_dl.dataset.transform))
    logger.info('valid transform: \n{}'.format(val_dl.dataset.transform))
    logger.info(type(model))
    logger.info(loss)
    logger.info(trainer)
    for epoch in range(trainer.epochs):
        for batch in trainer.train_dl:
            trainer.step(batch)
            trainer.handle_new_batch()
        trainer.handle_new_epoch()
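
The opts argument declared with argparse.REMAINDER is merged into the yacs-style config as alternating key/value tokens, which is what allows overriding individual options from the command line. A small, self-contained illustration of that merge behaviour (the keys and values here are only examples):

from yacs.config import CfgNode as CN

cfg = CN()
cfg.OUTPUT_DIR = './output'
cfg.SOLVER = CN()
cfg.SOLVER.BASE_LR = 0.00035

# Equivalent of passing "SOLVER.BASE_LR 0.0007 OUTPUT_DIR ./logs" after the script name.
cfg.merge_from_list(['SOLVER.BASE_LR', 0.0007, 'OUTPUT_DIR', './logs'])
cfg.freeze()
print(cfg.SOLVER.BASE_LR)  # 0.0007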
Example no. 7
def main():
    parser = argparse.ArgumentParser(description="ReID Baseline Training")
    parser.add_argument("--config_file", default="", help="path to config file", type=str)
    parser.add_argument("opts", help="Modify config options using the command-line", default=None,nargs=argparse.REMAINDER)
    args = parser.parse_args()
    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)
    num_gpus = torch.cuda.device_count()
    logger = setup_logger('reid_baseline', output_dir, 0)
    logger.info('Using {} GPUS'.format(num_gpus))
    logger.info('Running with config:\n{}'.format(cfg))
    if cfg.INPUT.SEPNORM.USE:
        train_dl, val_dl, num_query, num_classes = make_sepnorm_dataloader(cfg, num_gpus)
    elif cfg.DATASETS.EXEMPLAR.USE:
        train_dl, val_dl, num_query, num_classes,exemplar_dl = make_dataloader(cfg, num_gpus)
    else:
        train_dl, val_dl, num_query, num_classes = make_dataloader(cfg, num_gpus)

    model = build_model(cfg, num_classes)
    loss = make_loss(cfg, num_classes)
    if cfg.SOLVER.CENTER_LOSS.USE:
        trainer = CenterTrainer(cfg, model, train_dl, val_dl,
                      loss, num_query, num_gpus)
    else:
        if cfg.SOLVER.MIXUP.USE:
            trainer = NegMixupTrainer(cfg, model, train_dl, val_dl,
                              loss, num_query, num_gpus)
        elif cfg.DATASETS.EXEMPLAR.USE:
            if cfg.DATASETS.EXEMPLAR.MEMORY.USE:
                trainer = ExemplarMemoryTrainer(cfg, model, train_dl, val_dl,exemplar_dl,
                                  loss, num_query, num_gpus)
            else:
                trainer = UIRLTrainer(cfg, model, train_dl, val_dl,exemplar_dl,
                                  loss, num_query, num_gpus)
        elif cfg.DATASETS.HIST_LABEL.USE:
            trainer = HistLabelTrainer(cfg, model, train_dl, val_dl,
                    loss, num_query, num_gpus)
        else:
            trainer = BaseTrainer(cfg, model, train_dl, val_dl,
                              loss, num_query, num_gpus)
    if cfg.INPUT.SEPNORM.USE:
        logger.info('train transform0: \n{}'.format(train_dl.dataset.transform0))
        logger.info('train transform1: \n{}'.format(train_dl.dataset.transform1))

        logger.info('valid transform0: \n{}'.format(val_dl.dataset.transform0))
        logger.info('valid transform1: \n{}'.format(val_dl.dataset.transform1))

    else:
        logger.info('train transform: \n{}'.format(train_dl.dataset.transform))
        logger.info('valid transform: \n{}'.format(val_dl.dataset.transform))
    logger.info(type(model))
    logger.info(loss)
    logger.info(trainer)
    for epoch in range(trainer.epochs):
        for batch in trainer.train_dl:
            trainer.step(batch)
            trainer.handle_new_batch()
        trainer.handle_new_epoch()
Example no. 8
        ZeroCenter(),
        LinearSymplecticTwoByTwo(),
        SymplecticAdditiveCoupling(shift_model=IrrotationalMLP())
    ])
    #SymplecticAdditiveCoupling(shift_model=MLP())])
T = Chain(stack)

step = tf.get_variable("global_step", [],
                       tf.int64,
                       tf.zeros_initializer(),
                       trainable=False)

with tf.Session() as sess:

    z = make_data(settings, sess)

loss = make_loss(settings, T, z)

train_op = make_train_op(settings, loss, step)

# sess.run(tf.global_variables_initializer())

# Set the ZeroCenter bijectors to training mode:
for i, bijector in enumerate(T.bijectors):
    if hasattr(bijector, 'is_training'):
        T.bijectors[i].is_training = True

tf.contrib.training.train(train_op,
                          logdir=settings['log_dir'],
                          save_checkpoint_secs=60)
Example no. 9
def main():
  
    # Read the fold assignment CSV into a DataFrame
    train = pd.read_csv(args.folds_csv)
    # Directory where the model weights will be stored
    MODEL_PATH = os.path.join(args.models_dir, args.network + args.alias)
    # The input data is already pre-split into folds; parse the requested fold ids (e.g. 1,2,3,4,5)
    folds = [int(f) for f in args.fold.split(',')]

    print('Training Model:', args.network + args.alias)

    for fold in folds:

        K.clear_session()
        print('***************************** FOLD {} *****************************'.format(fold))

        if fold == 0:
            if os.path.isdir(MODEL_PATH):
                raise ValueError('Such Model already exists')
            os.system("mkdir {}".format(MODEL_PATH))

        # Train/Validation sampling
        df_train = train[train.fold != fold].copy().reset_index(drop=True)
        df_valid = train[train.fold == fold].copy().reset_index(drop=True)

        # Train on pseudolabels only
        if args.pseudolabels_dir != '':
            pseudolabels = pd.read_csv(args.pseudolabels_csv)
            df_train = pseudolabels.sample(frac=1, random_state=13).reset_index(drop=True)

        # Keep only non-black images
        ids_train, ids_valid = df_train[df_train.unique_pixels > 1].id.values, df_valid[
            df_valid.unique_pixels > 1].id.values

        print('Training on {} samples'.format(ids_train.shape[0]))
        print('Validating on {} samples'.format(ids_valid.shape[0]))

        # Initialize model
        weights_path = os.path.join(MODEL_PATH, 'fold_{fold}.hdf5'.format(fold=fold))

        # Get the model
        model, preprocess = get_model(args.network,
                                      input_shape=(args.input_size, args.input_size, 3),
                                      freeze_encoder=args.freeze_encoder)

        # LB metric threshold
        def lb_metric(y_true, y_pred):
            return Kaggle_IoU_Precision(y_true, y_pred, threshold=0 if args.loss_function == 'lovasz' else 0.5)

        model.compile(optimizer=RMSprop(lr=args.learning_rate), loss=make_loss(args.loss_function),
                      metrics=[lb_metric])

        if args.pretrain_weights is None:
            print('No weights passed, training from scratch')
        else:
            wp = args.pretrain_weights.format(fold)
            print('Loading weights from {}'.format(wp))
            model.load_weights(wp, by_name=True)

        # Get augmentations
        augs = get_augmentations(args.augmentation_name, p=args.augmentation_prob)

        # Data generator
        dg = SegmentationDataGenerator(input_shape=(args.input_size, args.input_size),
                                       batch_size=args.batch_size,
                                       augs=augs,
                                       preprocess=preprocess)

        train_generator = dg.train_batch_generator(ids_train)
        validation_generator = dg.evaluation_batch_generator(ids_valid)

        # Get callbacks
        callbacks = get_callback(args.callback,
                                 weights_path=weights_path,
                                 fold=fold)

        # Fit the model with Generators:
        model.fit_generator(generator=ThreadsafeIter(train_generator),
                            steps_per_epoch=ids_train.shape[0] // args.batch_size * 2,
                            epochs=args.epochs,
                            callbacks=callbacks,
                            validation_data=ThreadsafeIter(validation_generator),
                            validation_steps=np.ceil(ids_valid.shape[0] / args.batch_size),
                            workers=args.num_workers)

        gc.collect()
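
fit_generator above is fed through ThreadsafeIter so that several Keras workers can safely pull batches from the same Python generator. The wrapper is not included in this listing; a typical implementation, assumed here for illustration, is:

import threading

class ThreadsafeIter:
    """Wrap a generator so that next() calls are serialized by a lock (assumed implementation)."""

    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)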
Example no. 10
def main():
    if args.crop_size:
        print('Using crops of shape ({}, {})'.format(args.crop_size, args.crop_size))
    else:
        print('Using full size images')

    all_ids = np.array(generate_ids(args.data_dirs, args.clahe))
    np.random.seed(args.seed)
    kfold = KFold(n_splits=args.n_folds, shuffle=True)

    splits = [s for s in kfold.split(all_ids)]
    folds = [int(f) for f in args.fold.split(",")]
    for fold in folds:
        encoded_alias = encode_params(args.clahe, args.preprocessing_function, args.stretch_and_mean)
        city = "all"
        if args.city:
            city = args.city.lower()
        best_model_file = '{}/{}_{}_{}.h5'.format(args.models_dir, encoded_alias, city, args.network)
        channels = 8
        if args.ohe_city:
            channels = 12
        model = make_model(args.network, (None, None, channels))

        if args.weights is None:
            print('No weights passed, training from scratch')
        else:
            print('Loading weights from {}'.format(args.weights))
            model.load_weights(args.weights, by_name=True)
        freeze_model(model, args.freeze_till_layer)

        optimizer = RMSprop(lr=args.learning_rate)
        if args.optimizer:
            if args.optimizer == 'rmsprop':
                optimizer = RMSprop(lr=args.learning_rate)
            elif args.optimizer == 'adam':
                optimizer = Adam(lr=args.learning_rate)
            elif args.optimizer == 'sgd':
                optimizer = SGD(lr=args.learning_rate, momentum=0.9, nesterov=True)

        train_ind, test_ind = splits[fold]
        train_ids = all_ids[train_ind]
        val_ids = all_ids[test_ind]
        if args.city:
            val_ids = [id for id in val_ids if args.city in id[0]]
            train_ids = [id for id in train_ids if args.city in id[0]]
        print('Training fold #{}, {} in train_ids, {} in val_ids'.format(fold, len(train_ids), len(val_ids)))
        masks_gt = get_groundtruth(args.data_dirs)
        if args.clahe:
            template = 'CLAHE-MUL-PanSharpen/MUL-PanSharpen_{id}.tif'
        else:
            template = 'MUL-PanSharpen/MUL-PanSharpen_{id}.tif'

        train_generator = MULSpacenetDataset(
            data_dirs=args.data_dirs,
            wdata_dir=args.wdata_dir,
            clahe=args.clahe,
            batch_size=args.batch_size,
            image_ids=train_ids,
            masks_dict=masks_gt,
            image_name_template=template,
            seed=args.seed,
            ohe_city=args.ohe_city,
            stretch_and_mean=args.stretch_and_mean,
            preprocessing_function=args.preprocessing_function,
            crops_per_image=args.crops_per_image,
            crop_shape=(args.crop_size, args.crop_size),
            random_transformer=RandomTransformer(horizontal_flip=True, vertical_flip=True),
        )

        val_generator = MULSpacenetDataset(
            data_dirs=args.data_dirs,
            wdata_dir=args.wdata_dir,
            clahe=args.clahe,
            batch_size=1,
            image_ids=val_ids,
            image_name_template=template,
            masks_dict=masks_gt,
            seed=args.seed,
            ohe_city=args.ohe_city,
            stretch_and_mean=args.stretch_and_mean,
            preprocessing_function=args.preprocessing_function,
            shuffle=False,
            crops_per_image=1,
            crop_shape=(1280, 1280),
            random_transformer=None
        )
        best_model = ModelCheckpoint(filepath=best_model_file, monitor='val_dice_coef_clipped',
                                     verbose=1,
                                     mode='max',
                                     save_best_only=False,
                                     save_weights_only=True)
        model.compile(loss=make_loss(args.loss_function),
                      optimizer=optimizer,
                      metrics=[dice_coef, binary_crossentropy, ceneterline_loss, dice_coef_clipped])

        def schedule_steps(epoch, steps):
            for step in steps:
                if step[1] > epoch:
                    print("Setting learning rate to {}".format(step[0]))
                    return step[0]
            print("Setting learning rate to {}".format(steps[-1][0]))
            return steps[-1][0]

        callbacks = [best_model, EarlyStopping(patience=20, verbose=1, monitor='val_dice_coef_clipped', mode='max')]

        if args.schedule is not None:
            steps = [(float(step.split(":")[0]), int(step.split(":")[1])) for step in args.schedule.split(",")]
            lrSchedule = LearningRateScheduler(lambda epoch: schedule_steps(epoch, steps))
            callbacks.insert(0, lrSchedule)

        if args.clr is not None:
            clr_params = args.clr.split(',')
            base_lr = float(clr_params[0])
            max_lr = float(clr_params[1])
            step = int(clr_params[2])
            mode = clr_params[3]
            clr = CyclicLR(base_lr=base_lr, max_lr=max_lr, step_size=step, mode=mode)
            callbacks.append(clr)

        steps_per_epoch = len(all_ids) // args.batch_size + 1
        if args.steps_per_epoch:
            steps_per_epoch = args.steps_per_epoch

        model.fit_generator(
            train_generator,
            steps_per_epoch=steps_per_epoch,
            epochs=args.epochs,
            validation_data=val_generator,
            validation_steps=len(val_ids),
            callbacks=callbacks,
            max_queue_size=30,
            verbose=1,
            workers=args.num_workers)

        del model
        K.clear_session()
        gc.collect()
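
For reference, the --schedule and --clr strings consumed above are plain comma-separated encodings (lr:epoch pairs, and base_lr,max_lr,step_size,mode respectively). The snippet below mirrors the parsing done in the script, with illustrative values only:

schedule_arg = "0.001:20,0.0003:35,0.0001:50"
steps = [(float(s.split(":")[0]), int(s.split(":")[1])) for s in schedule_arg.split(",")]
print(steps)  # [(0.001, 20), (0.0003, 35), (0.0001, 50)]

clr_arg = "0.0001,0.001,2000,triangular"
base_lr, max_lr, step_size, mode = clr_arg.split(",")
print(float(base_lr), float(max_lr), int(step_size), mode)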
Example no. 11
def main():
    mask_dir = os.path.join(args.dataset_dir, args.train_mask_dir_name)
    val_mask_dir = os.path.join(args.dataset_dir, args.val_mask_dir_name)

    train_data_dir = os.path.join(args.dataset_dir, args.train_data_dir_name)
    val_data_dir = os.path.join(args.dataset_dir, args.val_data_dir_name)

    # mask_dir = 'data/train/masks_fail'
    # val_mask_dir = 'data/val/masks'
    #
    # train_data_dir = 'data/train/images_fail'
    # val_data_dir = 'data/val/images'

    if args.net_alias is not None:
        formatted_net_alias = '-{}-'.format(args.net_alias)
    else:
        formatted_net_alias = '-'  # fallback so the name template below still works

    best_model_file =\
        '{}/{}{}loss-{}-fold_{}-{}{:.6f}'.format(args.models_dir, args.network, formatted_net_alias, args.loss_function, args.fold, args.input_width, args.learning_rate) +\
        '-{epoch:d}-{val_loss:0.7f}-{val_dice_coef:0.7f}-{val_mean_io:0.7f}-{val_dice_coef_clipped:0.7f}.h5'
    if args.edges:
        ch = 5
    else:
        ch = 3
    model = make_model((None, None, args.stacked_channels + ch))
    freeze_model(model, args.freeze_till_layer)

    if args.weights is None:
        print('No weights passed, training from scratch')
    else:
        print('Loading weights from {}'.format(args.weights))
        model.load_weights(args.weights, by_name=True)

    optimizer = Adam(lr=args.learning_rate)

    if args.show_summary:
        model.summary()

    model.compile(loss=make_loss(args.loss_function),
                  optimizer=optimizer,
                  metrics=[
                      dice_coef_border, dice_coef, binary_crossentropy,
                      dice_coef_clipped, mean_iou
                  ])

    crop_size = None

    if args.use_crop:
        crop_size = (args.input_height, args.input_width)
        print('Using crops of shape ({}, {})'.format(args.input_height,
                                                     args.input_width))
    else:
        print('Using full size images, --use_crop=True to do crops')

    # folds_df = pd.read_csv(os.path.join(args.dataset_dir, args.folds_source))
    # train_ids = generate_filenames(folds_df[folds_df.fold != args.fold]['id'])
    # val_ids = generate_filenames(folds_df[folds_df.fold == args.fold]['id'])
    train_df = pd.read_csv('../data/train_df.csv')
    val_df = pd.read_csv('../data/val_df.csv')
    train_ids = [img + '.png' for img in train_df['id'].values]
    val_ids = [img + '.png' for img in val_df['id'].values]
    # train_ids = os.listdir(train_data_dir)
    # val_ids = os.listdir(val_data_dir)

    print('Training fold #{}, {} in train_ids, {} in val_ids'.format(
        args.fold, len(train_ids), len(val_ids)))

    train_generator = build_batch_generator(train_ids,
                                            img_dir=train_data_dir,
                                            batch_size=args.batch_size,
                                            shuffle=True,
                                            out_size=(args.out_height,
                                                      args.out_width),
                                            crop_size=crop_size,
                                            mask_dir=mask_dir,
                                            aug=True)

    val_generator = build_batch_generator(val_ids,
                                          img_dir=val_data_dir,
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          out_size=(args.out_height,
                                                    args.out_width),
                                          crop_size=crop_size,
                                          mask_dir=val_mask_dir,
                                          aug=False)

    best_model = ModelCheckpoint(best_model_file,
                                 monitor='val_loss',
                                 verbose=1,
                                 save_best_only=False,
                                 save_weights_only=True)

    callbacks = [
        best_model,
        EarlyStopping(patience=45, verbose=10),
        TensorBoard(log_dir='./logs',
                    histogram_freq=0,
                    write_graph=True,
                    write_images=True)
    ]
    if args.clr is not None:
        clr_params = args.clr.split(',')
        base_lr = float(clr_params[0])
        max_lr = float(clr_params[1])
        step = int(clr_params[2])
        mode = clr_params[3]
        clr = CyclicLR(base_lr=base_lr,
                       max_lr=max_lr,
                       step_size=step,
                       mode=mode)
        callbacks.append(clr)
    model.fit_generator(ThreadsafeIter(train_generator),
                        steps_per_epoch=len(train_ids) // args.batch_size + 1,
                        epochs=args.epochs,
                        validation_data=ThreadsafeIter(val_generator),
                        validation_steps=len(val_ids) // args.batch_size + 1,
                        callbacks=callbacks,
                        max_queue_size=50,
                        workers=4)
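
The long best_model_file template above works because Keras ModelCheckpoint formats its filepath with the epoch number and the values logged at the end of each epoch. A minimal example of that mechanism:

from keras.callbacks import ModelCheckpoint

# The placeholders are filled from the epoch index and the logs dict,
# producing e.g. weights-07-0.1234.h5 for epoch 7 with val_loss 0.1234.
checkpoint = ModelCheckpoint('weights-{epoch:02d}-{val_loss:.4f}.h5',
                             monitor='val_loss',
                             save_weights_only=True)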
Example no. 12
def main():
    if args.crop_size:
        print('Using crops of shape ({}, {})'.format(args.crop_size,
                                                     args.crop_size))
    else:
        print('Using full size images')

    all_ids = np.array(generate_ids(args.data_dirs, args.clahe))
    np.random.seed(args.seed)
    kfold = KFold(n_splits=args.n_folds, shuffle=True)

    splits = [s for s in kfold.split(all_ids)]
    folds = [int(f) for f in args.fold.split(",")]
    for fold in folds:
        encoded_alias = encode_params(args.clahe, args.preprocessing_function,
                                      args.stretch_and_mean)
        city = "all"
        if args.city:
            city = args.city.lower()
        best_model_file = '{}/{}_{}_{}.h5'.format(args.models_dir,
                                                  encoded_alias, city,
                                                  args.network)
        channels = 8
        if args.ohe_city:
            channels = 12
        model = make_model(args.network, (None, None, channels))

        if args.weights is None:
            print('No weights passed, training from scratch')
        else:
            print('Loading weights from {}'.format(args.weights))
            model.load_weights(args.weights, by_name=True)
        freeze_model(model, args.freeze_till_layer)

        optimizer = RMSprop(lr=args.learning_rate)
        if args.optimizer:
            if args.optimizer == 'rmsprop':
                optimizer = RMSprop(lr=args.learning_rate)
            elif args.optimizer == 'adam':
                optimizer = Adam(lr=args.learning_rate)
            elif args.optimizer == 'sgd':
                optimizer = SGD(lr=args.learning_rate,
                                momentum=0.9,
                                nesterov=True)

        train_ind, test_ind = splits[fold]
        train_ids = all_ids[train_ind]
        val_ids = all_ids[test_ind]
        if args.city:
            val_ids = [id for id in val_ids if args.city in id[0]]
            train_ids = [id for id in train_ids if args.city in id[0]]
        print('Training fold #{}, {} in train_ids, {} in val_ids'.format(
            fold, len(train_ids), len(val_ids)))
        masks_gt = get_groundtruth(args.data_dirs)
        if args.clahe:
            template = 'CLAHE-MUL-PanSharpen/MUL-PanSharpen_{id}.tif'
        else:
            template = 'MUL-PanSharpen/MUL-PanSharpen_{id}.tif'

        train_generator = MULSpacenetDataset(
            data_dirs=args.data_dirs,
            wdata_dir=args.wdata_dir,
            clahe=args.clahe,
            batch_size=args.batch_size,
            image_ids=train_ids,
            masks_dict=masks_gt,
            image_name_template=template,
            seed=args.seed,
            ohe_city=args.ohe_city,
            stretch_and_mean=args.stretch_and_mean,
            preprocessing_function=args.preprocessing_function,
            crops_per_image=args.crops_per_image,
            crop_shape=(args.crop_size, args.crop_size),
            random_transformer=RandomTransformer(horizontal_flip=True,
                                                 vertical_flip=True),
        )

        val_generator = MULSpacenetDataset(
            data_dirs=args.data_dirs,
            wdata_dir=args.wdata_dir,
            clahe=args.clahe,
            batch_size=1,
            image_ids=val_ids,
            image_name_template=template,
            masks_dict=masks_gt,
            seed=args.seed,
            ohe_city=args.ohe_city,
            stretch_and_mean=args.stretch_and_mean,
            preprocessing_function=args.preprocessing_function,
            shuffle=False,
            crops_per_image=1,
            crop_shape=(1280, 1280),
            random_transformer=None)
        best_model = ModelCheckpoint(filepath=best_model_file,
                                     monitor='val_dice_coef_clipped',
                                     verbose=1,
                                     mode='max',
                                     save_best_only=False,
                                     save_weights_only=True)
        model.compile(loss=make_loss(args.loss_function),
                      optimizer=optimizer,
                      metrics=[
                          dice_coef, binary_crossentropy, ceneterline_loss,
                          dice_coef_clipped
                      ])

        def schedule_steps(epoch, steps):
            for step in steps:
                if step[1] > epoch:
                    print("Setting learning rate to {}".format(step[0]))
                    return step[0]
            print("Setting learning rate to {}".format(steps[-1][0]))
            return steps[-1][0]

        callbacks = [
            best_model,
            EarlyStopping(patience=20,
                          verbose=1,
                          monitor='val_dice_coef_clipped',
                          mode='max')
        ]

        if args.schedule is not None:
            steps = [(float(step.split(":")[0]), int(step.split(":")[1]))
                     for step in args.schedule.split(",")]
            lrSchedule = LearningRateScheduler(
                lambda epoch: schedule_steps(epoch, steps))
            callbacks.insert(0, lrSchedule)

        if args.clr is not None:
            clr_params = args.clr.split(',')
            base_lr = float(clr_params[0])
            max_lr = float(clr_params[1])
            step = int(clr_params[2])
            mode = clr_params[3]
            clr = CyclicLR(base_lr=base_lr,
                           max_lr=max_lr,
                           step_size=step,
                           mode=mode)
            callbacks.append(clr)

        steps_per_epoch = len(all_ids) // args.batch_size + 1
        if args.steps_per_epoch:
            steps_per_epoch = args.steps_per_epoch

        model.fit_generator(train_generator,
                            steps_per_epoch=steps_per_epoch,
                            epochs=args.epochs,
                            validation_data=val_generator,
                            validation_steps=len(val_ids),
                            callbacks=callbacks,
                            max_queue_size=30,
                            verbose=1,
                            workers=args.num_workers)

        del model
        K.clear_session()
        gc.collect()
Example no. 13
    random.seed(1234)
    torch.backends.cudnn.deterministic = True
    cudnn.benchmark = True

    #CONFIG PARSER
    config = get_args()
    output_path = config.output_path
    make_dirs(output_path)
    logger = setup_logger('reid_baseline',output_path,if_train=True)

    train_loader,train_gen_loader, val_loader, num_query, num_classes = make_dataloader(config)
    model = Backbone(num_classes,config)
    # if config.pretrain:
    #     model.load_param_finetune(config.m_pretrain_path)

    loss_func, center_criterion = make_loss(config, num_classes=num_classes)
    optimizer, optimizer_center = make_optimizer( model, center_criterion)
    scheduler = WarmupMultiStepLR(optimizer, [40,70], 0.1,
                                  0.01,
                                  10, 'linear')

    log_period = config.log_interval
    checkpoint_period = config.save_model_interval
    eval_period = config.test_interval

    device = "cuda"
    epochs = 80

    logger = logging.getLogger("reid_baseline.train")
    logger.info('start training')
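
WarmupMultiStepLR above is configured with milestones [40, 70], gamma 0.1, a warmup factor of 0.01 over 10 epochs, and linear warmup (the same positional order used later in this listing). A rough sketch of the resulting schedule, under those assumed semantics:

def lr_at_epoch(base_lr, epoch, milestones=(40, 70), gamma=0.1,
                warmup_factor=0.01, warmup_epochs=10):
    # Linear warmup from warmup_factor * base_lr up to base_lr, then a 10x decay
    # each time a milestone epoch is passed.
    scale = 1.0
    if epoch < warmup_epochs:
        alpha = epoch / warmup_epochs
        scale = warmup_factor * (1 - alpha) + alpha
    return base_lr * scale * gamma ** sum(1 for m in milestones if epoch >= m)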
Example no. 14
def train(args):
    if args.batch_size % args.num_instance != 0:
        new_batch_size = (args.batch_size //
                          args.num_instance) * args.num_instance
        print(
            f"Given batch size is {args.batch_size} and num_instances is {args.num_instance}. "
            f"Batch size must be divisible by {args.num_instance}; it will be replaced with {new_batch_size}."
        )
        args.batch_size = new_batch_size

    # prepare dataset
    train_loader, val_loader, num_query, train_data_len, num_classes = make_data_loader(
        args)

    model = build_model(args, num_classes)
    print("model size: {:.5f}M".format(
        sum(p.numel() for p in model.parameters()) / 1e6))
    loss_fn, center_criterion = make_loss(args, num_classes)
    optimizer, optimizer_center = make_optimizer(args, model, center_criterion)

    if args.cuda:
        model = model.cuda()
        if args.amp:
            if args.center_loss:
                model, [optimizer, optimizer_center] = \
                    amp.initialize(model, [optimizer, optimizer_center], opt_level="O1")
            else:
                model, optimizer = amp.initialize(model,
                                                  optimizer,
                                                  opt_level="O1")

        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        if args.center_loss:
            center_criterion = center_criterion.cuda()
            for state in optimizer_center.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()

    model_state_dict = model.state_dict()
    optim_state_dict = optimizer.state_dict()
    if args.center_loss:
        optim_center_state_dict = optimizer_center.state_dict()
        center_state_dict = center_criterion.state_dict()

    reid_evaluator = ReIDEvaluator(args, model, num_query)

    start_epoch = 0
    global_step = 0
    if args.pretrain != '':  # load pre-trained model
        weights = torch.load(args.pretrain)
        model_state_dict = weights["state_dict"]

        model.load_state_dict(model_state_dict)
        if args.center_loss:
            center_criterion.load_state_dict(
                torch.load(args.pretrain.replace(
                    'model', 'center_param'))["state_dict"])

        if args.resume:
            start_epoch = weights["epoch"]
            global_step = weights["global_step"]

            optimizer.load_state_dict(
                torch.load(args.pretrain.replace('model',
                                                 'optimizer'))["state_dict"])
            if args.center_loss:
                optimizer_center.load_state_dict(
                    torch.load(
                        args.pretrain.replace(
                            'model', 'optimizer_center'))["state_dict"])
        print(f'Start epoch: {start_epoch}, Start step: {global_step}')

    scheduler = WarmupMultiStepLR(optimizer, args.steps, args.gamma,
                                  args.warmup_factor, args.warmup_step,
                                  "linear",
                                  -1 if start_epoch == 0 else start_epoch)

    current_epoch = start_epoch
    best_epoch = 0
    best_rank1 = 0
    best_mAP = 0
    if args.resume:
        rank, mAP = reid_evaluator.evaluate(val_loader)
        best_rank1 = rank[0]
        best_mAP = mAP
        best_epoch = current_epoch + 1

    batch_time = AverageMeter()
    total_losses = AverageMeter()

    model_save_dir = os.path.join(args.save_dir, 'ckpts')
    os.makedirs(model_save_dir, exist_ok=True)

    summary_writer = SummaryWriter(log_dir=os.path.join(
        args.save_dir, "tensorboard_log"),
                                   purge_step=global_step)

    def summary_loss(score, feat, labels, top_name='global'):
        loss = 0.0
        losses = loss_fn(score, feat, labels)
        for loss_name, loss_val in losses.items():
            if loss_name.lower() == "accuracy":
                summary_writer.add_scalar(f"Score/{top_name}/triplet",
                                          loss_val, global_step)
                continue
            if "dist" in loss_name.lower():
                summary_writer.add_histogram(f"Distance/{loss_name}", loss_val,
                                             global_step)
                continue
            loss += loss_val
            summary_writer.add_scalar(f"losses/{top_name}/{loss_name}",
                                      loss_val, global_step)

        ohe_labels = torch.zeros_like(score)
        ohe_labels.scatter_(1, labels.unsqueeze(1), 1.0)

        cls_score = torch.softmax(score, dim=1)
        cls_score = torch.sum(cls_score * ohe_labels, dim=1).mean()
        summary_writer.add_scalar(f"Score/{top_name}/X-entropy", cls_score,
                                  global_step)

        return loss

    def save_weights(file_name, eph, steps):
        torch.save(
            {
                "state_dict": model_state_dict,
                "epoch": eph + 1,
                "global_step": steps
            }, file_name)
        torch.save({"state_dict": optim_state_dict},
                   file_name.replace("model", "optimizer"))
        if args.center_loss:
            # keep the file names consistent with how they are loaded above
            torch.save({"state_dict": center_state_dict},
                       file_name.replace("model", "center_param"))
            torch.save({"state_dict": optim_center_state_dict},
                       file_name.replace("model", "optimizer_center"))

    # training start
    for epoch in range(start_epoch, args.max_epoch):
        model.train()
        t0 = time.time()
        for i, (inputs, labels, _, _) in enumerate(train_loader):
            if args.cuda:
                inputs = inputs.cuda()
                labels = labels.cuda()

            cls_scores, features = model(inputs, labels)

            # losses
            total_loss = summary_loss(cls_scores[0], features[0], labels,
                                      'global')
            if args.use_local_feat:
                total_loss += summary_loss(cls_scores[1], features[1], labels,
                                           'local')

            optimizer.zero_grad()
            if args.center_loss:
                optimizer_center.zero_grad()

            # backward with global loss
            if args.amp:
                optimizers = [optimizer]
                if args.center_loss:
                    optimizers.append(optimizer_center)
                with amp.scale_loss(total_loss, optimizers) as scaled_loss:
                    scaled_loss.backward()
            else:
                with torch.autograd.detect_anomaly():
                    total_loss.backward()

            # optimization
            optimizer.step()
            if args.center_loss:
                for name, param in center_criterion.named_parameters():
                    try:
                        param.grad.data *= (1. / args.center_loss_weight)
                    except AttributeError:
                        continue
                optimizer_center.step()

            batch_time.update(time.time() - t0)
            total_losses.update(total_loss.item())

            # learning_rate
            current_lr = optimizer.param_groups[0]['lr']
            summary_writer.add_scalar("lr", current_lr, global_step)

            t0 = time.time()

            if (i + 1) % args.log_period == 0:
                print(
                    f"Epoch: [{epoch}][{i+1}/{train_data_len}]  "
                    f"Batch Time {batch_time.val:.3f} ({batch_time.mean:.3f})  "
                    f"Total_loss {total_losses.val:.3f} ({total_losses.mean:.3f})"
                )
            global_step += 1

        print(
            f"Epoch: [{epoch}]\tEpoch Time {batch_time.sum:.3f} s\tLoss {total_losses.mean:.3f}\tLr {current_lr:.2e}"
        )

        if args.eval_period > 0 and (epoch + 1) % args.eval_period == 0 or (
                epoch + 1) == args.max_epoch:
            rank, mAP = reid_evaluator.evaluate(
                val_loader,
                mode="retrieval" if args.dataset_name == "cub200" else "reid")

            rank_string = ""
            for r in (1, 2, 4, 5, 8, 10, 16, 20):
                rank_string += f"Rank-{r:<3}: {rank[r-1]:.1%}"
                if r != 20:
                    rank_string += "    "
            summary_writer.add_text("Recall@K", rank_string, global_step)
            summary_writer.add_scalar("Rank-1", rank[0], (epoch + 1))

            rank1 = rank[0]
            is_best = rank1 > best_rank1
            if is_best:
                best_rank1 = rank1
                best_mAP = mAP
                best_epoch = epoch + 1

            if (epoch + 1) % args.save_period == 0 or (epoch +
                                                       1) == args.max_epoch:
                pth_file_name = os.path.join(
                    model_save_dir,
                    f"{args.backbone}_model_{epoch + 1}.pth.tar")
                save_weights(pth_file_name, eph=epoch, steps=global_step)

            if is_best:
                pth_file_name = os.path.join(
                    model_save_dir, f"{args.backbone}_model_best.pth.tar")
                save_weights(pth_file_name, eph=epoch, steps=global_step)

        # end epoch
        current_epoch += 1

        batch_time.reset()
        total_losses.reset()
        torch.cuda.empty_cache()

        # update learning rate
        scheduler.step()

    print(f"Best rank-1 {best_rank1:.1%}, achived at epoch {best_epoch}")
    summary_writer.add_hparams(
        {
            "dataset_name": args.dataset_name,
            "triplet_dim": args.triplet_dim,
            "margin": args.margin,
            "base_lr": args.base_lr,
            "use_attn": args.use_attn,
            "use_mask": args.use_mask,
            "use_local_feat": args.use_local_feat
        }, {
            "mAP": best_mAP,
            "Rank1": best_rank1
        })
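
batch_time and total_losses above are AverageMeter instances exposing val, mean, sum and reset(). The class is not shown in this listing; an assumed, conventional implementation consistent with that usage:

class AverageMeter:
    """Track the last value, running sum and running mean of a series (assumed implementation)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.mean = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.mean = self.sum / self.count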
Example no. 15
def main():
    if args.crop_size:
        print('Using crops of shape ({}, {})'.format(args.crop_size,
                                                     args.crop_size))
    else:
        print('Using full size images')
    folds = [int(f) for f in args.fold.split(",")]
    for fold in folds:
        channels = 3
        if args.multi_gpu:
            with K.tf.device("/cpu:0"):
                model = make_model(args.network, (None, None, 3))
        else:
            model = make_model(args.network, (None, None, channels))
        if args.weights is None:
            print('No weights passed, training from scratch')
        else:
            weights_path = args.weights.format(fold)
            print('Loading weights from {}'.format(weights_path))
            model.load_weights(weights_path, by_name=True)
        freeze_model(model, args.freeze_till_layer)
        optimizer = RMSprop(lr=args.learning_rate)
        if args.optimizer:
            if args.optimizer == 'rmsprop':
                optimizer = RMSprop(lr=args.learning_rate,
                                    decay=float(args.decay))
            elif args.optimizer == 'adam':
                optimizer = Adam(lr=args.learning_rate,
                                 decay=float(args.decay))
            elif args.optimizer == 'amsgrad':
                optimizer = Adam(lr=args.learning_rate,
                                 decay=float(args.decay),
                                 amsgrad=True)
            elif args.optimizer == 'sgd':
                optimizer = SGD(lr=args.learning_rate,
                                momentum=0.9,
                                nesterov=True,
                                decay=float(args.decay))
        dataset = DSB2018BinaryDataset(args.images_dir,
                                       args.masks_dir,
                                       args.labels_dir,
                                       fold,
                                       args.n_folds,
                                       seed=args.seed)
        random_transform = aug_mega_hardcore()
        train_generator = dataset.train_generator(
            (args.crop_size, args.crop_size),
            args.preprocessing_function,
            random_transform,
            batch_size=args.batch_size)
        val_generator = dataset.val_generator(args.preprocessing_function,
                                              batch_size=1)
        best_model_file = '{}/best_{}{}_fold{}.h5'.format(
            args.models_dir, args.alias, args.network, fold)

        best_model = ModelCheckpointMGPU(model,
                                         filepath=best_model_file,
                                         monitor='val_loss',
                                         verbose=1,
                                         mode='min',
                                         period=args.save_period,
                                         save_best_only=True,
                                         save_weights_only=True)
        last_model_file = '{}/last_{}{}_fold{}.h5'.format(
            args.models_dir, args.alias, args.network, fold)

        last_model = ModelCheckpointMGPU(model,
                                         filepath=last_model_file,
                                         monitor='val_loss',
                                         verbose=1,
                                         mode='min',
                                         period=args.save_period,
                                         save_best_only=False,
                                         save_weights_only=True)
        if args.multi_gpu:
            model = multi_gpu_model(model, len(gpus))
        model.compile(
            loss=make_loss(args.loss_function),
            optimizer=optimizer,
            metrics=[binary_crossentropy, hard_dice_coef_ch1, hard_dice_coef])

        def schedule_steps(epoch, steps):
            for step in steps:
                if step[1] > epoch:
                    print("Setting learning rate to {}".format(step[0]))
                    return step[0]
            print("Setting learning rate to {}".format(steps[-1][0]))
            return steps[-1][0]

        callbacks = [best_model, last_model]

        if args.schedule is not None:
            steps = [(float(step.split(":")[0]), int(step.split(":")[1]))
                     for step in args.schedule.split(",")]
            lrSchedule = LearningRateScheduler(
                lambda epoch: schedule_steps(epoch, steps))
            callbacks.insert(0, lrSchedule)
        tb = TensorBoard("logs/{}_{}".format(args.network, fold))
        callbacks.append(tb)
        steps_per_epoch = len(dataset.train_ids) // args.batch_size + 1
        if args.steps_per_epoch > 0:
            steps_per_epoch = args.steps_per_epoch
        validation_data = val_generator
        validation_steps = len(dataset.val_ids)

        model.fit_generator(train_generator,
                            steps_per_epoch=steps_per_epoch,
                            epochs=args.epochs,
                            validation_data=validation_data,
                            validation_steps=validation_steps,
                            callbacks=callbacks,
                            max_queue_size=5,
                            verbose=1,
                            workers=args.num_workers)

        del model
        K.clear_session()
        gc.collect()
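
ModelCheckpointMGPU above takes the original single-GPU model as its first argument so that checkpoints hold the plain model's weights even when training runs through multi_gpu_model. The class is not part of this listing; a common implementation of that idea, assumed here, is:

from keras.callbacks import ModelCheckpoint

class ModelCheckpointMGPU(ModelCheckpoint):
    def __init__(self, original_model, **kwargs):
        self.original_model = original_model
        super().__init__(**kwargs)

    def on_epoch_end(self, epoch, logs=None):
        # Save the template model's weights instead of the multi-GPU wrapper's.
        self.model = self.original_model
        super().on_epoch_end(epoch, logs)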
Example no. 16
def main():
    man_dir = args.manual_dataset_dir
    auto_dir = args.auto_dataset_dir
    # man_mask_dir = os.path.join(args.manual_dataset_dir, args.train_mask_dir_name)
    # man_val_mask_dir = os.path.join(args.manual_dataset_dir, args.val_mask_dir_name)
    # auto_mask_dir = os.path.join(args.manual_dataset_dir, args.train_mask_dir_name)
    # auto_val_mask_dir = os.path.join(args.manual_dataset_dir, args.val_mask_dir_name)
    #
    # man_train_data_dir = os.path.join(args.auto_dataset_dir, args.train_data_dir_name)
    # man_val_data_dir = os.path.join(args.auto_dataset_dir, args.val_data_dir_name)
    # auto_train_data_dir = os.path.join(args.auto_dataset_dir, args.train_data_dir_name)
    # auto_val_data_dir = os.path.join(args.auto_dataset_dir, args.val_data_dir_name)
    # man_mask
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.1
    set_session(tf.Session(config=config))
    # mask_dir = 'data/train/masks_fail'
    # val_mask_dir = 'data/val/masks'
    #
    # train_data_dir = 'data/train/images_fail'
    # val_data_dir = 'data/val/images'

    if args.net_alias is not None:
        formatted_net_alias = '-{}-'.format(args.net_alias)
    else:
        formatted_net_alias = '-'  # fallback so the name template below still works

    best_model_file =\
        '{}/{}{}loss-{}-fold_{}-{}{:.6f}'.format(args.models_dir, args.network, formatted_net_alias, args.loss_function, args.fold, args.input_width, args.learning_rate) +\
        '-{epoch:d}-{val_prediction_loss:0.7f}-{val_prediction_dice_coef:0.7f}-{val_prediction_f1_score:0.7f}-{val_nadir_output_acc:0.7f}-{val_tangent_output_acc:0.7f}.h5'
    if args.edges:
        ch = 3
    else:
        ch = 3
    # model = make_model((None, None, args.stacked_channels + ch))
    model = make_model((args.input_height, args.input_width, args.stacked_channels + ch))

    freeze_model(model, args.freeze_till_layer)

    if args.weights is None:
        print('No weights passed, training from scratch')
    else:
        print('Loading weights from {}'.format(args.weights))
        model.load_weights(args.weights, by_name=True)

    optimizer = Adam(lr=args.learning_rate)

    if args.show_summary:
        model.summary()

    model.compile(loss=[make_loss(args.loss_function), 'categorical_crossentropy', 'categorical_crossentropy'],
                  optimizer=optimizer, loss_weights=[1, 0.5, 0.11],
                  metrics={'prediction': [dice_coef_border, dice_coef, binary_crossentropy, dice_coef_clipped, f1_score],
                           'nadir_output': 'accuracy',
                           'tangent_output': 'accuracy'})

    crop_size = None

    if args.use_crop:
        crop_size = (args.input_height, args.input_width)
        print('Using crops of shape ({}, {})'.format(args.input_height, args.input_width))
    else:
        print('Using full size images, --use_crop=True to do crops')

    train_df = pd.read_csv(args.train_df)
    val_df = pd.read_csv(args.val_df)
    # folds_df = pd.read_csv(os.path.join(args.dataset_dir, args.folds_source))
    # train_ids = generate_filenames(folds_df[folds_df.fold != args.fold]['id'])
    # val_ids = generate_filenames(folds_df[folds_df.fold == args.fold]['id'])
    # train_ids = os.listdir(train_data_dir)
    # val_ids = os.listdir(val_data_dir)

    print('Training fold #{}, {} in train_ids, {} in val_ids'.format(args.fold, len(train_df), len(val_df)))

    train_generator = build_batch_generator(
        train_df,
        img_man_dir=man_dir,
        img_auto_dir=auto_dir,
        batch_size=args.batch_size,
        shuffle=True,
        out_size=(args.out_height, args.out_width),
        crop_size=crop_size,
        # mask_dir=mask_dir,
        aug=True
    )

    val_generator = build_batch_generator(
        val_df,
        img_man_dir=man_dir,
        img_auto_dir=auto_dir,
        batch_size=args.batch_size,
        shuffle=False,
        out_size=(args.out_height, args.out_width),
        crop_size=crop_size,
        # mask_dir=val_mask_dir,
        aug=False
    )

    best_model = ModelCheckpoint(best_model_file, monitor='val_prediction_dice_coef',
                                                  verbose=1,
                                                  save_best_only=False,
                                                  save_weights_only=True,
                                                  mode='max')

    callbacks = [best_model,
                 # EarlyStopping(patience=45, verbose=10),
                 TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=True, write_images=True),
                 ]
                 # ReduceLROnPlateau(monitor='val_prediction_dice_coef', mode='max', factor=0.2, patience=5, min_lr=0.00001,
                 #                   verbose=1)]
    if args.clr is not None:
        clr_params = args.clr.split(',')
        base_lr = float(clr_params[0])
        max_lr = float(clr_params[1])
        step = int(clr_params[2])
        mode = clr_params[3]
        clr = CyclicLR(base_lr=base_lr, max_lr=max_lr, step_size=step, mode=mode)
        callbacks.append(clr)
    model.fit_generator(
        ThreadsafeIter(train_generator),
        steps_per_epoch=len(train_df) // args.batch_size + 1,
        epochs=args.epochs,
        validation_data=ThreadsafeIter(val_generator),
        validation_steps=len(val_df) // args.batch_size + 1,
        callbacks=callbacks,
        max_queue_size=50,
        workers=4)
def main():
    args = parse_args()
    config_path = args.config_file_path

    config = get_config(config_path, new_keys_allowed=True)

    config.defrost()
    config.experiment_dir = os.path.join(config.log_dir, config.experiment_name)
    config.tb_dir = os.path.join(config.experiment_dir, 'tb')
    config.model.best_checkpoint_path = os.path.join(config.experiment_dir, 'best_checkpoint.pt')
    config.model.last_checkpoint_path = os.path.join(config.experiment_dir, 'last_checkpoint.pt')
    config.config_save_path = os.path.join(config.experiment_dir, 'segmentation_config.yaml')
    config.freeze()

    init_experiment(config)
    set_random_seed(config.seed)

    train_dataset = make_dataset(config.train.dataset)
    train_loader = make_data_loader(config.train.loader, train_dataset)

    val_dataset = make_dataset(config.val.dataset)
    val_loader = make_data_loader(config.val.loader, val_dataset)

    device = torch.device(config.device)
    model = make_model(config.model).to(device)

    optimizer = make_optimizer(config.optim, model.parameters())
    scheduler = None

    loss_f = make_loss(config.loss)

    early_stopping = EarlyStopping(
        **config.stopper.params
    )

    train_writer = SummaryWriter(log_dir=os.path.join(config.tb_dir, 'train'))
    val_writer = SummaryWriter(log_dir=os.path.join(config.tb_dir, 'val'))

    for epoch in range(1, config.epochs + 1):
        print(f'Epoch {epoch}')
        train_metrics = train(model, optimizer, train_loader, loss_f, device)
        write_metrics(epoch, train_metrics, train_writer)
        print_metrics('Train', train_metrics)

        val_metrics = val(model, val_loader, loss_f, device)
        write_metrics(epoch, val_metrics, val_writer)
        print_metrics('Val', val_metrics)

        early_stopping(val_metrics['loss'])
        if config.model.save and early_stopping.counter == 0:
            torch.save(model.state_dict(), config.model.best_checkpoint_path)
            print('Saved best model checkpoint to disk.')
        if early_stopping.early_stop:
            print(f'Early stopping after {epoch} epochs.')
            break

        if scheduler:
            scheduler.step()

    train_writer.close()
    val_writer.close()

    if config.model.save:
        torch.save(model.state_dict(), config.model.last_checkpoint_path)
        print('Saved last model checkpoint to disk.')
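
The early_stopping object above is called once per epoch with the validation loss; counter == 0 marks an improvement (used to save the best checkpoint) and early_stop signals that patience has run out. The class is not shown; a minimal sketch consistent with that usage, assuming patience/min_delta parameters:

class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True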
def get_models(img_rows, img_cols, fold):
    model, process = get_model(network=network, input_shape=(img_rows, img_cols, 1),
                               freeze_encoder=False)
    model.load_weights(weights_path + str(fold) + '.hdf5')
    model.compile(optimizer=Adam(lr=learning_rate), loss=make_loss(loss_name=loss_function), metrics=[dice_coef])
    return model, process