def post(self):
    """Train the segmentation model.

    Expects an ``application/json`` body with the training parameters
    understood by :class:`SegmentationTrainConfig` plus a ``'spots'`` list
    (each spot a dict with at least a ``'t'`` timepoint key).

    Returns:
        A JSON response: ``{'completed': bool}`` on success/abort, or
        ``{'error': msg}`` with status 400 (bad request) / 500 (busy or
        internal failure).
    """
    import time  # local import: only needed for the polling loop below

    # Use .get() so a missing Content-Type header yields the intended 400
    # response instead of an unhandled KeyError (-> 500 with traceback).
    if request.headers.get('Content-Type') != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    try:
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = SegmentationTrainConfig(req_json)
        logger().info(config)
        if config.n_crops < 1:
            msg = 'n_crops should be a positive number'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)

        # Group annotated spots by timepoint, sorted ascending.
        spots_dict = collections.defaultdict(list)
        for spot in req_json.get('spots'):
            spots_dict[spot['t']].append(spot)
        if not (spots_dict or config.is_livemode):
            msg = 'nothing to train'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

        # Only one training process may run at a time.
        if get_state() != TrainState.IDLE.value:
            msg = 'Process is running'
            logger().error(msg)
            return make_response(jsonify(error=msg), 500)
        redis_client.set(REDIS_KEY_STATE, TrainState.RUN.value)
        if config.is_livemode:
            redis_client.delete(REDIS_KEY_TIMEPOINT)
        else:
            try:
                _update_seg_labels(spots_dict,
                                   config.scales,
                                   config.zpath_input,
                                   config.zpath_seg_label,
                                   config.zpath_seg_label_vis,
                                   config.auto_bg_thresh,
                                   config.c_ratio,
                                   memmap_dir=config.memmap_dir)
            except KeyboardInterrupt:
                return make_response(jsonify({'completed': False}))

        # Resume the TensorBoard global step from the latest event file so
        # new logs continue after the previous run instead of overlapping.
        step_offset = 0
        for path in sorted(Path(config.log_dir).glob('event*')):
            try:
                *_, last_record = TFRecordDataset(str(path))
                last = event_pb2.Event.FromString(last_record.numpy()).step
                step_offset = max(step_offset, last + 1)
            except Exception:
                # Best-effort: skip empty or unreadable event files.
                pass
        epoch_start = 0

        async_result = train_seg_task.delay(
            list(spots_dict.keys()), config.batch_size, config.crop_size,
            config.class_weights, config.false_weight, config.model_path,
            config.n_epochs, config.keep_axials, config.scales, config.lr,
            config.n_crops, config.is_3d, config.is_livemode,
            config.scale_factor_base, config.rotation_angle,
            config.contrast, config.zpath_input, config.zpath_seg_label,
            config.log_interval, config.log_dir, step_offset, epoch_start,
            config.is_cpu(), config.is_mixed_precision,
            config.cache_maxbytes, config.memmap_dir, config.input_size,
        )
        # Poll the Celery task, honoring an external abort (state reset to
        # IDLE). Sleep between polls to avoid a 100%-CPU busy-wait.
        while not async_result.ready():
            if (redis_client is not None
                    and get_state() == TrainState.IDLE.value):
                logger().info('training aborted')
                return make_response(jsonify({'completed': False}))
            time.sleep(1)
    except Exception as e:
        logger().exception('Failed in train_seg')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        # Always release GPU memory and mark the server idle again.
        torch.cuda.empty_cache()
        redis_client.set(REDIS_KEY_STATE, TrainState.IDLE.value)
    return make_response(jsonify({'completed': True}))
def main():
    """CLI entry point for offline ``seg`` or ``flow`` training.

    Usage: ``prog {seg|flow} config.json [--baseconfig base.json]``

    The config file holds a list of per-dataset dicts; ``--baseconfig``
    supplies defaults merged under each entry. Builds the train/eval
    datasets for every entry, then launches training — single-process when
    ``PROFILE`` is set, otherwise via ``torch.multiprocessing.spawn``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('command', help='seg | flow')
    parser.add_argument('config', help='config file')
    parser.add_argument('--baseconfig', help='base config file')
    args = parser.parse_args()
    if args.command not in ('seg', 'flow'):
        print('command option should be "seg" or "flow"')
        parser.print_help()
        # SystemExit instead of the site-injected exit() helper, which is
        # unavailable under `python -S` / frozen interpreters.
        raise SystemExit(1)

    base_config_dict = dict()
    if args.baseconfig is not None:
        with open(args.baseconfig, 'r', encoding='utf-8') as jsonfile:
            base_config_dict.update(json.load(jsonfile))
    with open(args.config, 'r', encoding='utf-8') as jsonfile:
        config_data = json.load(jsonfile)

    # Rendezvous address for torch.distributed process groups.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # Prepare one (train, eval) dataset pair per config entry.
    train_datasets = []
    eval_datasets = []
    for entry in config_data:
        config_dict = base_config_dict.copy()
        config_dict.update(entry)
        config = ResetConfig(config_dict)
        n_dims = 2 + config.is_3d  # 3 or 2
        models = load_models(config, args.command)
        za_input = zarr.open(config.zpath_input, mode='r')
        input_size = config_dict.get('input_size', za_input.shape[-n_dims:])
        train_index = []
        eval_index = []
        # evalinterval: None -> all timepoints train; -1 -> every labeled
        # timepoint is used for both train and eval; k -> every k-th
        # labeled timepoint is eval-only.
        eval_interval = config_dict.get('evalinterval')
        t_min = config_dict.get('t_min', 0)
        t_max = config_dict.get('t_max', za_input.shape[0] - 1)
        train_length = config_dict.get('train_length')
        eval_length = config_dict.get('eval_length')
        adaptive_length = config_dict.get('adaptive_length', False)
        if args.command == 'seg':
            config = SegmentationTrainConfig(config_dict)
            print(config)
            za_label = zarr.open(config.zpath_seg_label, mode='r')
            for ti, t in enumerate(range(t_min, t_max + 1)):
                # Only timepoints that actually contain labels.
                if 0 < za_label[t].max():
                    if eval_interval is not None and eval_interval == -1:
                        train_index.append(t)
                        eval_index.append(t)
                    elif (eval_interval is not None
                          and (ti + 1) % eval_interval == 0):
                        eval_index.append(t)
                    else:
                        train_index.append(t)
            train_datasets.append(SegmentationDatasetZarr(
                config.zpath_input,
                config.zpath_seg_label,
                train_index,
                input_size,
                config.crop_size,
                config.n_crops,
                keep_axials=config.keep_axials,
                scales=config.scales,
                scale_factor_base=config.scale_factor_base,
                contrast=config.contrast,
                rotation_angle=config.rotation_angle,
                length=train_length,
                cache_maxbytes=config.cache_maxbytes,
                memmap_dir=config.memmap_dir,
            ))
            eval_datasets.append(SegmentationDatasetZarr(
                config.zpath_input,
                config.zpath_seg_label,
                eval_index,
                input_size,
                input_size,
                1,
                keep_axials=config.keep_axials,
                is_eval=True,
                length=eval_length,
                adaptive_length=adaptive_length,
                cache_maxbytes=config.cache_maxbytes,
                memmap_dir=config.memmap_dir,
            ))
        elif args.command == 'flow':
            config = FlowTrainConfig(config_dict)
            print(config)
            za_label = zarr.open(config.zpath_flow_label, mode='r')
            # Flow labels pair consecutive frames, hence t_max exclusive.
            for ti, t in enumerate(range(t_min, t_max)):
                if 0 < za_label[t][-1].max():
                    if eval_interval is not None and eval_interval == -1:
                        train_index.append(t)
                        eval_index.append(t)
                    elif (eval_interval is not None
                          and (ti + 1) % eval_interval == 0):
                        eval_index.append(t)
                    else:
                        train_index.append(t)
            train_datasets.append(FlowDatasetZarr(
                config.zpath_input,
                config.zpath_flow_label,
                train_index,
                input_size,
                config.crop_size,
                config.n_crops,
                keep_axials=config.keep_axials,
                scales=config.scales,
                scale_factor_base=config.scale_factor_base,
                rotation_angle=config.rotation_angle,
                length=train_length,
                cache_maxbytes=config.cache_maxbytes,
                memmap_dir=config.memmap_dir,
            ))
            eval_datasets.append(FlowDatasetZarr(
                config.zpath_input,
                config.zpath_flow_label,
                eval_index,
                input_size,
                input_size,
                1,
                keep_axials=config.keep_axials,
                is_eval=True,
                length=eval_length,
                adaptive_length=adaptive_length,
                cache_maxbytes=config.cache_maxbytes,
                memmap_dir=config.memmap_dir,
            ))
    train_dataset = ConcatDataset(train_datasets)
    eval_dataset = ConcatDataset(eval_datasets)
    if 0 < len(train_dataset):
        # NOTE(review): `config` and `models` here are those of the LAST
        # config entry — loss/optimizer/loader hyper-parameters come from
        # it even though the datasets are concatenated across all entries.
        # This preserves the original behavior; confirm it is intended.
        if args.command == 'seg':
            weight_tensor = torch.tensor(config.class_weights)
            loss_fn = SegmentationLoss(class_weights=weight_tensor,
                                       false_weight=config.false_weight,
                                       is_3d=config.is_3d)
        elif args.command == 'flow':
            loss_fn = FlowLoss(is_3d=config.is_3d)
        optimizers = [
            torch.optim.Adam(model.parameters(), lr=config.lr)
            for model in models
        ]
        train_loader = DataLoader(train_dataset,
                                  shuffle=True,
                                  batch_size=config.batch_size)
        eval_loader = DataLoader(eval_dataset,
                                 shuffle=False,
                                 batch_size=config.batch_size)
        # Resume the TensorBoard global step from existing event files.
        step_offset = 0
        for path in sorted(Path(config.log_dir).glob('event*')):
            try:
                *_, last_record = TFRecordDataset(str(path))
                last = event_pb2.Event.FromString(last_record.numpy()).step
                step_offset = max(step_offset, last + 1)
            except Exception:
                # Best-effort: skip empty or unreadable event files.
                pass
        if PROFILE:
            # Single-process run for profiling.
            run_train(config.device, 1, models, train_loader, optimizers,
                      loss_fn, config.n_epochs, config.model_path, False,
                      config.log_interval, config.log_dir, step_offset,
                      config.epoch_start, eval_loader, config.patch_size,
                      config.is_cpu(), args.command == 'seg')
        else:
            # 2 workers on CPU, else one per visible GPU.
            world_size = 2 if config.is_cpu() else torch.cuda.device_count()
            mp.spawn(run_train,
                     args=(world_size, models, train_loader, optimizers,
                           loss_fn, config.n_epochs, config.model_path,
                           False, config.log_interval, config.log_dir,
                           step_offset, config.epoch_start, eval_loader,
                           config.patch_size, config.is_cpu(),
                           args.command == 'seg'),
                     nprocs=world_size,
                     join=True)