Example #1
    def post(self):
        '''
        Train a segmentation model.
        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        try:
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = SegmentationTrainConfig(req_json)
            logger().info(config)
            if config.n_crops < 1:
                msg = 'n_crops should be a positive number'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)

            # Group the requested spots by timepoint.
            spots_dict = collections.defaultdict(list)
            for spot in req_json.get('spots', []):
                spots_dict[spot['t']].append(spot)
            if not (spots_dict or config.is_livemode):
                msg = 'nothing to train'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

            if get_state() != TrainState.IDLE.value:
                msg = 'Process is running'
                logger().error(msg)
                return make_response(jsonify(error=msg), 500)
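            # Mark the process as running so concurrent requests are rejected.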
            redis_client.set(REDIS_KEY_STATE, TrainState.RUN.value)
            if config.is_livemode:
                redis_client.delete(REDIS_KEY_TIMEPOINT)
            else:
                try:
                    _update_seg_labels(spots_dict,
                                       config.scales,
                                       config.zpath_input,
                                       config.zpath_seg_label,
                                       config.zpath_seg_label_vis,
                                       config.auto_bg_thresh,
                                       config.c_ratio,
                                       memmap_dir=config.memmap_dir)
                except KeyboardInterrupt:
                    return make_response(jsonify({'completed': False}))
            # Resume the TensorBoard global step from the last recorded event.
            step_offset = 0
            for path in sorted(Path(config.log_dir).glob('event*')):
                try:
                    *_, last_record = TFRecordDataset(str(path))
                    last = event_pb2.Event.FromString(last_record.numpy()).step
                    step_offset = max(step_offset, last + 1)
                except Exception:
                    # Skip unreadable or empty event files.
                    pass
            epoch_start = 0
            async_result = train_seg_task.delay(
                list(spots_dict.keys()),
                config.batch_size,
                config.crop_size,
                config.class_weights,
                config.false_weight,
                config.model_path,
                config.n_epochs,
                config.keep_axials,
                config.scales,
                config.lr,
                config.n_crops,
                config.is_3d,
                config.is_livemode,
                config.scale_factor_base,
                config.rotation_angle,
                config.contrast,
                config.zpath_input,
                config.zpath_seg_label,
                config.log_interval,
                config.log_dir,
                step_offset,
                epoch_start,
                config.is_cpu(),
                config.is_mixed_precision,
                config.cache_maxbytes,
                config.memmap_dir,
                config.input_size,
            )
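            # Poll the background task; a state reset to IDLE signals an abort.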
            while not async_result.ready():
                if (redis_client is not None
                        and get_state() == TrainState.IDLE.value):
                    logger().info('training aborted')
                    return make_response(jsonify({'completed': False}))
        except Exception as e:
            logger().exception('Failed in train_seg')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            torch.cuda.empty_cache()
            redis_client.set(REDIS_KEY_STATE, TrainState.IDLE.value)
        return make_response(jsonify({'completed': True}))
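
A minimal client-side sketch of how this handler might be exercised. The
endpoint URL and most payload keys are assumptions: the handler only reads
'spots' (each spot carrying a 't' field) directly, injects 'device' itself,
and hands the remaining keys to SegmentationTrainConfig.

import requests

# Hypothetical route; the actual URL mapping is defined elsewhere in the app.
URL = 'http://localhost:8080/train/seg'

payload = {
    # Annotated spots; the server groups them by timepoint 't'.
    'spots': [
        {'t': 0, 'pos': [10.0, 20.0, 5.0]},  # 'pos' is illustrative only
        {'t': 1, 'pos': [11.0, 21.0, 5.0]},
    ],
    # Training parameters consumed by SegmentationTrainConfig (names assumed
    # from the attributes used above, e.g. n_crops, n_epochs).
    'n_crops': 1,
    'n_epochs': 10,
}

# requests sets Content-Type: application/json, as the handler requires.
resp = requests.post(URL, json=payload)
print(resp.status_code, resp.json())  # expects {'completed': True} on success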
Example #2
import argparse
import io
import json
import os
from pathlib import Path

import torch
import torch.multiprocessing as mp
import zarr
from tensorflow.core.util import event_pb2
from tensorflow.data import TFRecordDataset
from torch.utils.data import ConcatDataset, DataLoader

# Project-internal names (ResetConfig, SegmentationTrainConfig,
# FlowTrainConfig, load_models, SegmentationDatasetZarr, FlowDatasetZarr,
# SegmentationLoss, FlowLoss, run_train, PROFILE) are assumed to come from
# the surrounding package.


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('command', help='seg | flow')
    parser.add_argument('config', help='config file')
    parser.add_argument('--baseconfig', help='base config file')
    args = parser.parse_args()
    if args.command not in ('seg', 'flow'):
        parser.error('command should be "seg" or "flow"')
    base_config_dict = dict()
    if args.baseconfig is not None:
        with io.open(args.baseconfig, 'r', encoding='utf-8') as jsonfile:
            base_config_dict.update(json.load(jsonfile))
    with io.open(args.config, 'r', encoding='utf-8') as jsonfile:
        config_data = json.load(jsonfile)
    # Rendezvous settings for torch.distributed multi-process training.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # prepare dataset
    train_datasets = []
    eval_datasets = []
    for entry in config_data:
        config_dict = base_config_dict.copy()
        config_dict.update(entry)
        config = ResetConfig(config_dict)
        n_dims = 2 + config.is_3d  # 3 or 2
        models = load_models(config, args.command)
        za_input = zarr.open(config.zpath_input, mode='r')
        input_size = config_dict.get('input_size', za_input.shape[-n_dims:])
        train_index = []
        eval_index = []
        eval_interval = config_dict.get('evalinterval')
        t_min = config_dict.get('t_min', 0)
        t_max = config_dict.get('t_max', za_input.shape[0] - 1)
        train_length = config_dict.get('train_length')
        eval_length = config_dict.get('eval_length')
        adaptive_length = config_dict.get('adaptive_length', False)
        if args.command == 'seg':
            config = SegmentationTrainConfig(config_dict)
            print(config)
            za_label = zarr.open(config.zpath_seg_label, mode='r')
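            # Only timepoints with labels are used; an eval_interval of -1
            # puts every such frame in both the train and eval sets, otherwise
            # every eval_interval-th frame is held out for evaluation.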
            for ti, t in enumerate(range(t_min, t_max + 1)):
                if 0 < za_label[t].max():
                    if eval_interval is not None and eval_interval == -1:
                        train_index.append(t)
                        eval_index.append(t)
                    elif (eval_interval is not None
                          and (ti + 1) % eval_interval == 0):
                        eval_index.append(t)
                    else:
                        train_index.append(t)
            train_datasets.append(
                SegmentationDatasetZarr(
                    config.zpath_input,
                    config.zpath_seg_label,
                    train_index,
                    input_size,
                    config.crop_size,
                    config.n_crops,
                    keep_axials=config.keep_axials,
                    scales=config.scales,
                    scale_factor_base=config.scale_factor_base,
                    contrast=config.contrast,
                    rotation_angle=config.rotation_angle,
                    length=train_length,
                    cache_maxbytes=config.cache_maxbytes,
                    memmap_dir=config.memmap_dir,
                ))
            eval_datasets.append(
                SegmentationDatasetZarr(
                    config.zpath_input,
                    config.zpath_seg_label,
                    eval_index,
                    input_size,
                    input_size,
                    1,
                    keep_axials=config.keep_axials,
                    is_eval=True,
                    length=eval_length,
                    adaptive_length=adaptive_length,
                    cache_maxbytes=config.cache_maxbytes,
                    memmap_dir=config.memmap_dir,
                ))
        elif args.command == 'flow':
            config = FlowTrainConfig(config_dict)
            print(config)
            za_label = zarr.open(config.zpath_flow_label, mode='r')
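            # Flow labels pair frame t with t+1, so the final timepoint is
            # excluded; a frame counts as annotated if its last label channel
            # is non-empty.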
            for ti, t in enumerate(range(t_min, t_max)):
                if 0 < za_label[t][-1].max():
                    if eval_interval is not None and eval_interval == -1:
                        train_index.append(t)
                        eval_index.append(t)
                    elif (eval_interval is not None
                          and (ti + 1) % eval_interval == 0):
                        eval_index.append(t)
                    else:
                        train_index.append(t)
            train_datasets.append(
                FlowDatasetZarr(
                    config.zpath_input,
                    config.zpath_flow_label,
                    train_index,
                    input_size,
                    config.crop_size,
                    config.n_crops,
                    keep_axials=config.keep_axials,
                    scales=config.scales,
                    scale_factor_base=config.scale_factor_base,
                    rotation_angle=config.rotation_angle,
                    length=train_length,
                    cache_maxbytes=config.cache_maxbytes,
                    memmap_dir=config.memmap_dir,
                ))
            eval_datasets.append(
                FlowDatasetZarr(
                    config.zpath_input,
                    config.zpath_flow_label,
                    eval_index,
                    input_size,
                    input_size,
                    1,
                    keep_axials=config.keep_axials,
                    is_eval=True,
                    length=eval_length,
                    adaptive_length=adaptive_length,
                    cache_maxbytes=config.cache_maxbytes,
                    memmap_dir=config.memmap_dir,
                ))
    train_dataset = ConcatDataset(train_datasets)
    eval_dataset = ConcatDataset(eval_datasets)
    if 0 < len(train_dataset):
        if args.command == 'seg':
            weight_tensor = torch.tensor(config.class_weights)
            loss_fn = SegmentationLoss(class_weights=weight_tensor,
                                       false_weight=config.false_weight,
                                       is_3d=config.is_3d)
        elif args.command == 'flow':
            loss_fn = FlowLoss(is_3d=config.is_3d)
        optimizers = [
            torch.optim.Adam(model.parameters(), lr=config.lr)
            for model in models
        ]
        train_loader = DataLoader(train_dataset,
                                  shuffle=True,
                                  batch_size=config.batch_size)
        eval_loader = DataLoader(eval_dataset,
                                 shuffle=False,
                                 batch_size=config.batch_size)
        # Resume the TensorBoard global step from the last recorded event.
        step_offset = 0
        for path in sorted(Path(config.log_dir).glob('event*')):
            try:
                *_, last_record = TFRecordDataset(str(path))
                last = event_pb2.Event.FromString(last_record.numpy()).step
                step_offset = max(step_offset, last + 1)
            except Exception:
                # Skip unreadable or empty event files.
                pass
        if PROFILE:
            run_train(config.device, 1, models, train_loader, optimizers,
                      loss_fn, config.n_epochs, config.model_path, False,
                      config.log_interval, config.log_dir, step_offset,
                      config.epoch_start, eval_loader, config.patch_size,
                      config.is_cpu(), args.command == 'seg')
        else:
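            # Two worker processes on CPU, otherwise one process per GPU.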
            world_size = (2 if config.is_cpu() else torch.cuda.device_count())
            mp.spawn(run_train,
                     args=(world_size, models, train_loader, optimizers,
                           loss_fn, config.n_epochs, config.model_path, False,
                           config.log_interval, config.log_dir, step_offset,
                           config.epoch_start, eval_loader, config.patch_size,
                           config.is_cpu(), args.command == 'seg'),
                     nprocs=world_size,
                     join=True)
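

if __name__ == '__main__':
    main()

A sketch of driving main() end to end, assuming the module above is saved as
train.py (the actual filename is not shown; the __main__ guard above is added
for runnability). The per-run config file must contain a JSON list, since
main() iterates over its entries; --baseconfig supplies defaults that each
entry overrides.

import json
import subprocess

# Shared defaults; keys mirror those read by the configs above.
base = {
    'zpath_input': 'data/input.zarr',
    'batch_size': 4,
    'lr': 5e-4,
}
# One dict per dataset/run; values here override the base config.
runs = [
    {'zpath_seg_label': 'data/seg_labels.zarr', 'n_epochs': 10},
]

with open('base.json', 'w') as f:
    json.dump(base, f)
with open('config.json', 'w') as f:
    json.dump(runs, f)

# 'train.py' is a placeholder name for the module that defines main().
subprocess.run(
    ['python', 'train.py', 'seg', 'config.json', '--baseconfig', 'base.json'],
    check=True,
)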