Example #1
0
def test_net(data_dir, cross_valid_ind=1, cfg=None):
    """Build the validation dataset described by ``cfg`` and collect its labels.

    Despite the name, no network is evaluated here: the function iterates the
    validation split once and gathers the second element of every sample.

    Args:
        data_dir: Dataset root forwarded to the dataset factory.
        cross_valid_ind: Cross-validation fold index; only used on the
            non-"Cell_nuclei" path (``create_dataset``).
        cfg: Configuration mapping (required). Keys read: ``'dataset'``
            (optional), ``'img_size'``, and ``'eval_resize'`` for
            "Cell_nuclei" or ``'crop'`` otherwise.

    Returns:
        list: ``data[1].asnumpy()`` for each sample, in iteration order.

    Raises:
        TypeError: If ``cfg`` is None.
    """
    # cfg defaults to None, yet every path below indexes it; fail with a
    # clear message rather than "argument of type 'NoneType' is not iterable".
    if cfg is None:
        raise TypeError("test_net() requires a cfg mapping, got None")

    if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
        # Positional 1, 1 follow the (repeat, batch_size) order used in the
        # training calls of this factory.
        valid_dataset = create_cell_nuclei_dataset(
            data_dir,
            cfg['img_size'],
            1,
            1,
            is_train=False,
            eval_resize=cfg["eval_resize"],
            split=0.8)
    else:
        # create_dataset yields (train, valid); only the validation split is kept.
        _, valid_dataset = create_dataset(data_dir,
                                          1,
                                          1,
                                          False,
                                          cross_valid_ind,
                                          False,
                                          do_crop=cfg['crop'],
                                          img_size=cfg['img_size'])

    # data[1] is presumably the ground-truth mask of each sample — TODO confirm
    # against the dataset pipeline.
    return [data[1].asnumpy() for data in valid_dataset]
Example #2
0
def test_net(data_dir, ckpt_path, cross_valid_ind=1, cfg=None):
    """Evaluate a UNet checkpoint on the validation split and print its scores.

    Args:
        data_dir: Dataset root forwarded to the dataset factory.
        ckpt_path: Checkpoint file whose parameters are loaded into the net.
        cross_valid_ind: Cross-validation fold index; only used on the
            non-"Cell_nuclei" path (``create_dataset``).
        cfg: Configuration mapping (required). Keys read: ``'model'``,
            ``'num_channels'``, ``'num_classes'``, ``'img_size'``;
            ``'use_deconv'``/``'use_bn'`` for ``'unet_nested'``;
            ``'dataset'`` (optional) plus ``'eval_resize'`` or ``'crop'``.

    Raises:
        TypeError: If ``cfg`` is None.
        ValueError: If ``cfg['model']`` is not a supported model name.
    """
    # cfg defaults to None, yet it is indexed immediately; reject it with a
    # clear message instead of "'NoneType' object is not subscriptable".
    if cfg is None:
        raise TypeError("test_net() requires a cfg mapping, got None")

    model_name = cfg['model']
    if model_name == 'unet_medical':
        net = UNetMedical(n_channels=cfg['num_channels'],
                          n_classes=cfg['num_classes'])
    elif model_name == 'unet_nested':
        # Deep supervision is forced off for evaluation regardless of
        # cfg['use_ds'], so UnetEval is used without need_slice below.
        net = NestedUNet(in_channel=cfg['num_channels'],
                         n_class=cfg['num_classes'],
                         use_deconv=cfg['use_deconv'],
                         use_bn=cfg['use_bn'],
                         use_ds=False)
    elif model_name == 'unet_simple':
        net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
    else:
        raise ValueError("Unsupported model: {}".format(cfg['model']))

    load_param_into_net(net, load_checkpoint(ckpt_path))
    net = UnetEval(net)

    if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
        valid_dataset = create_cell_nuclei_dataset(
            data_dir,
            cfg['img_size'],
            1,
            1,
            is_train=False,
            eval_resize=cfg["eval_resize"],
            split=0.8)
    else:
        # create_dataset yields (train, valid); only the validation split is kept.
        _, valid_dataset = create_dataset(data_dir,
                                          1,
                                          1,
                                          False,
                                          cross_valid_ind,
                                          False,
                                          do_crop=cfg['crop'],
                                          img_size=cfg['img_size'])

    model = Model(net,
                  loss_fn=TempLoss(),
                  metrics={"dice_coeff": dice_coeff()})

    print("============== Starting Evaluating ============")
    eval_score = model.eval(valid_dataset,
                            dataset_sink_mode=False)["dice_coeff"]
    # The metric result is indexed as (dice, IOU), per the messages below.
    print("============== Cross valid dice coeff is:", eval_score[0])
    print("============== Cross valid IOU is:", eval_score[1])
Example #3
0
def train_net(args_opt,
              cross_valid_ind=1,
              epochs=400,
              batch_size=16,
              lr=0.0001,
              cfg=None):
    """Train a UNet-family network, optionally evaluating it during training.

    Args:
        args_opt: Parsed options. Attributes read: ``data_url``,
            ``run_distribute``, ``run_eval`` and, when ``run_eval`` is truthy,
            ``eval_metrics``, ``eval_interval`` and ``eval_start_epoch``.
        cross_valid_ind: Cross-validation fold index forwarded to
            ``create_dataset`` (non-"Cell_nuclei" datasets only).
        epochs: Total epoch budget. ``cfg['repeat']`` is also handed to the
            dataset factory, so ``model.train`` runs
            ``epochs // cfg['repeat']`` outer epochs.
        batch_size: Training batch size.
        lr: Adam learning rate.
        cfg: Configuration mapping (required).

    Raises:
        TypeError: If ``cfg`` is None.
        ValueError: If ``cfg['model']`` is unsupported, or if
            ``cfg['repeat']`` is not positive or exceeds ``epochs`` (which
            would schedule zero training epochs).
    """
    # cfg defaults to None but is indexed throughout; reject it before any
    # side effect such as distributed initialisation happens.
    if cfg is None:
        raise TypeError("train_net() requires a cfg mapping, got None")

    rank = 0
    group_size = 1
    data_dir = args_opt.data_url
    run_distribute = args_opt.run_distribute
    if run_distribute:
        init()
        group_size = get_group_size()
        rank = get_rank()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            device_num=group_size,
            gradients_mean=False)

    # Forwarded to UnetEval; only set when the nested UNet uses use_ds.
    need_slice = False
    if cfg['model'] == 'unet_medical':
        net = UNetMedical(n_channels=cfg['num_channels'],
                          n_classes=cfg['num_classes'])
    elif cfg['model'] == 'unet_nested':
        net = NestedUNet(in_channel=cfg['num_channels'],
                         n_class=cfg['num_classes'],
                         use_deconv=cfg['use_deconv'],
                         use_bn=cfg['use_bn'],
                         use_ds=cfg['use_ds'])
        need_slice = cfg['use_ds']
    elif cfg['model'] == 'unet_simple':
        net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
    else:
        raise ValueError("Unsupported model: {}".format(cfg['model']))

    if cfg['resume']:
        param_dict = load_checkpoint(cfg['resume_ckpt'])
        if cfg['transfer_training']:
            # Filter the loaded parameters by cfg['filter_weight'] before
            # loading them (transfer-training mode).
            filter_checkpoint_parameter_by_list(param_dict,
                                                cfg['filter_weight'])
        load_param_into_net(net, param_dict)

    if 'use_ds' in cfg and cfg['use_ds']:
        criterion = MultiCrossEntropyWithLogits()
    else:
        criterion = CrossEntropyWithLogits()

    repeat = cfg['repeat']
    # Validate before building datasets: epochs < repeat would otherwise
    # schedule zero outer epochs only after all the setup work is done.
    if repeat <= 0 or epochs < repeat:
        raise ValueError(
            "epochs ({}) must be >= cfg['repeat'] ({}), and cfg['repeat'] "
            "must be positive".format(epochs, repeat))
    # Integer division avoids the float round-trip of int(epochs / repeat).
    train_epochs = int(epochs // repeat)

    if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
        dataset_sink_mode = True
        per_print_times = 0
        train_dataset = create_cell_nuclei_dataset(data_dir,
                                                   cfg['img_size'],
                                                   repeat,
                                                   batch_size,
                                                   is_train=True,
                                                   augment=True,
                                                   split=0.8,
                                                   rank=rank,
                                                   group_size=group_size)
        valid_dataset = create_cell_nuclei_dataset(
            data_dir,
            cfg['img_size'],
            1,
            1,
            is_train=False,
            eval_resize=cfg["eval_resize"],
            split=0.8,
            python_multiprocessing=False)
    else:
        dataset_sink_mode = False
        per_print_times = 1
        train_dataset, valid_dataset = create_dataset(
            data_dir, repeat, batch_size, True, cross_valid_ind,
            run_distribute, cfg["crop"], cfg['img_size'])
    train_data_size = train_dataset.get_dataset_size()
    print("dataset length is:", train_data_size)

    # device_id is not defined in this function; it is expected at module scope.
    ckpt_dir = './ckpt_{}/'.format(device_id)
    # One checkpoint every train_data_size steps, i.e. once per dataset pass
    # (assuming get_dataset_size counts batches — TODO confirm).
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=train_data_size,
        keep_checkpoint_max=cfg['keep_checkpoint_max'])
    ckpt_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(cfg['model']),
                              directory=ckpt_dir,
                              config=ckpt_config)

    optimizer = nn.Adam(params=net.trainable_params(),
                        learning_rate=lr,
                        weight_decay=cfg['weight_decay'],
                        loss_scale=cfg['loss_scale'])

    loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(
        cfg['FixedLossScaleManager'], False)

    model = Model(net,
                  loss_fn=criterion,
                  loss_scale_manager=loss_scale_manager,
                  optimizer=optimizer,
                  amp_level="O3")

    print("============== Starting Training ==============")
    callbacks = [
        StepLossTimeMonitor(batch_size=batch_size,
                            per_print_times=per_print_times),
        ckpt_cb,
    ]
    if args_opt.run_eval:
        # Use the cfg passed to this function (not a module-level config) so
        # the evaluation metric matches the training configuration.
        eval_model = Model(UnetEval(net, need_slice=need_slice),
                           loss_fn=TempLoss(),
                           metrics={"dice_coeff": dice_coeff(cfg, False)})
        eval_param_dict = {
            "model": eval_model,
            "dataset": valid_dataset,
            "metrics_name": args_opt.eval_metrics
        }
        eval_cb = EvalCallBack(apply_eval,
                               eval_param_dict,
                               interval=args_opt.eval_interval,
                               eval_start_epoch=args_opt.eval_start_epoch,
                               save_best_ckpt=True,
                               ckpt_directory=ckpt_dir,
                               # NOTE(review): 'besk' looks like a typo but must
                               # match EvalCallBack's parameter name — confirm
                               # before renaming.
                               besk_ckpt_name="best.ckpt",
                               metrics_name=args_opt.eval_metrics)
        callbacks.append(eval_cb)
    model.train(train_epochs,
                train_dataset,
                callbacks=callbacks,
                dataset_sink_mode=dataset_sink_mode)
    print("============== End Training ==============")
Example #4
0
def train_net(data_dir,
              cross_valid_ind=1,
              epochs=400,
              batch_size=16,
              lr=0.0001,
              run_distribute=False,
              cfg=None):
    """Train a UNet-family network on the dataset described by ``cfg``.

    Args:
        data_dir: Dataset root forwarded to the dataset factory.
        cross_valid_ind: Cross-validation fold index forwarded to
            ``create_dataset`` (non-"Cell_nuclei" datasets only).
        epochs: Total epoch budget. For "Cell_nuclei" the dataset factory gets
            a repeat of 10 and ``model.train`` runs ``epochs // 10`` outer
            epochs; otherwise the factory gets ``repeat=epochs`` and a single
            outer epoch is run.
        batch_size: Training batch size.
        lr: Adam learning rate.
        run_distribute: Initialise data-parallel training when truthy.
        cfg: Configuration mapping (required).

    Raises:
        TypeError: If ``cfg`` is None.
        ValueError: If ``cfg['model']`` is unsupported, or if ``epochs`` is
            too small for the repeat count (zero outer epochs would be run).
    """
    # cfg defaults to None but is indexed throughout; reject it before any
    # side effect such as distributed initialisation happens.
    if cfg is None:
        raise TypeError("train_net() requires a cfg mapping, got None")

    rank = 0
    group_size = 1
    if run_distribute:
        init()
        group_size = get_group_size()
        rank = get_rank()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            device_num=group_size,
            gradients_mean=False)

    if cfg['model'] == 'unet_medical':
        net = UNetMedical(n_channels=cfg['num_channels'],
                          n_classes=cfg['num_classes'])
    elif cfg['model'] == 'unet_nested':
        net = NestedUNet(in_channel=cfg['num_channels'],
                         n_class=cfg['num_classes'],
                         use_deconv=cfg['use_deconv'],
                         use_bn=cfg['use_bn'],
                         use_ds=cfg['use_ds'])
    elif cfg['model'] == 'unet_simple':
        net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
    else:
        raise ValueError("Unsupported model: {}".format(cfg['model']))

    if cfg['resume']:
        param_dict = load_checkpoint(cfg['resume_ckpt'])
        if cfg['transfer_training']:
            # Filter the loaded parameters by cfg['filter_weight'] before
            # loading them (transfer-training mode).
            filter_checkpoint_parameter_by_list(param_dict,
                                                cfg['filter_weight'])
        load_param_into_net(net, param_dict)

    if 'use_ds' in cfg and cfg['use_ds']:
        criterion = MultiCrossEntropyWithLogits()
    else:
        criterion = CrossEntropyWithLogits()

    is_cell_nuclei = 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei"
    # Cell_nuclei uses a fixed repeat of 10; other datasets receive the whole
    # epoch budget as their repeat count and train one outer epoch.
    repeat = 10 if is_cell_nuclei else epochs
    # Validate before building datasets: a repeat larger than epochs would
    # otherwise schedule zero outer epochs only after all the setup work.
    if repeat <= 0 or epochs < repeat:
        raise ValueError(
            "epochs ({}) must be >= the dataset repeat count ({}), and the "
            "repeat count must be positive".format(epochs, repeat))
    # Integer division avoids the float round-trip of int(epochs / repeat).
    train_epochs = int(epochs // repeat)

    if is_cell_nuclei:
        dataset_sink_mode = True
        per_print_times = 0
        train_dataset = create_cell_nuclei_dataset(data_dir,
                                                   cfg['img_size'],
                                                   repeat,
                                                   batch_size,
                                                   is_train=True,
                                                   augment=True,
                                                   split=0.8,
                                                   rank=rank,
                                                   group_size=group_size)
    else:
        dataset_sink_mode = False
        per_print_times = 1
        # create_dataset yields (train, valid); only the training split is used.
        train_dataset, _ = create_dataset(data_dir, repeat, batch_size, True,
                                          cross_valid_ind, run_distribute,
                                          cfg["crop"], cfg['img_size'])
    train_data_size = train_dataset.get_dataset_size()
    print("dataset length is:", train_data_size)

    # One checkpoint every train_data_size steps, i.e. once per dataset pass
    # (assuming get_dataset_size counts batches — TODO confirm).
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=train_data_size,
        keep_checkpoint_max=cfg['keep_checkpoint_max'])
    # device_id is not defined in this function; it is expected at module scope.
    ckpt_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(cfg['model']),
                              directory='./ckpt_{}/'.format(device_id),
                              config=ckpt_config)

    optimizer = nn.Adam(params=net.trainable_params(),
                        learning_rate=lr,
                        weight_decay=cfg['weight_decay'],
                        loss_scale=cfg['loss_scale'])

    loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(
        cfg['FixedLossScaleManager'], False)

    model = Model(net,
                  loss_fn=criterion,
                  loss_scale_manager=loss_scale_manager,
                  optimizer=optimizer,
                  amp_level="O3")

    print("============== Starting Training ==============")
    callbacks = [
        StepLossTimeMonitor(batch_size=batch_size,
                            per_print_times=per_print_times),
        ckpt_cb,
    ]
    model.train(train_epochs,
                train_dataset,
                callbacks=callbacks,
                dataset_sink_mode=dataset_sink_mode)
    print("============== End Training ==============")