async def update_stock_info(info_file, info, create=True, verbose=VERBOSE):
    try:
        # Clean key names
        clean_info = clean_enumeration(info)
        clean_info.pop('matchScore', None)
        # Read previous info
        if info_file.exists():
            old_info = await read_info_file(info_file, check=False, verbose=verbose)
        else:
            old_info = {}
        await save_stock_info(info_file, clean_info, old_info=old_info, create=create)
        if verbose > 1:
            symbol = info_file.parent.name
            LOG.info(f"Updating {symbol} info:{get_tabs(symbol, prev=15)}OK")
    except Exception as err:
        # format_tb returns the traceback as strings (print_tb prints to stderr
        # and would interpolate None into the message)
        LOG.error(f"ERROR updating info: {info_file}. Msg: {err!r} "
                  f"{''.join(traceback.format_tb(err.__traceback__))}")
def _do_calculation(self):
    # prepare data for calculation
    y_true = self.cache_y_true[:self.idx_y_true]
    y_pred = self.cache_y_pred[:self.idx_y_pred]
    avg = self.idx_y_true // self.n_workers
    # multiprocessing for auc/ap scores
    LOG.info("Waiting for AUC/AP score calculation...")
    with Pool(processes=self.n_workers) as pool:
        scores = pool.starmap(self._calculation_job, [
            (y_true[avg * i:avg * (i + 1)], y_pred[avg * i:avg * (i + 1)])
            if i != self.n_workers - 1
            else (y_true[avg * i:], y_pred[avg * i:])
            for i in range(self.n_workers)
        ])
    scores = np.array(scores, dtype=np.float32)
    # apply results
    auc_col, auc_row, ap_col, ap_row = scores.mean(axis=0)
    self.running_auc_col += auc_col
    self.running_auc_row += auc_row
    self.running_ap_col += ap_col
    self.running_ap_row += ap_row
    # count step
    self.counter += 1
    # reset cache and idx pointer
    self.reset_cache()
    return
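# A minimal, self-contained check (hypothetical sizes) that the per-worker slices used
# in _do_calculation partition the cached arrays exactly once, with the last worker
# absorbing the remainder. Wrapped in a helper so it does not run at import time.
def _check_worker_chunking():
    import numpy as np

    n_items, n_workers = 103, 4  # hypothetical cache size and worker count
    y = np.arange(n_items)
    avg = n_items // n_workers
    chunks = [y[avg * i:avg * (i + 1)] if i != n_workers - 1 else y[avg * i:]
              for i in range(n_workers)]
    assert sum(len(c) for c in chunks) == n_items  # every sample appears exactly once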
def _get_activation(name: str) -> Optional[nn.Module]:
    if name.lower() == 'relu':
        return nn.ReLU()
    # TODO: test other activations
    else:
        LOG.error(f'Unrecognized activation function: [bold red]{name}[/].')
        raise ValueError(f'Unrecognized activation function: {name}.')
def find_data(ref, db):
    idx = ref.parent.name
    data = [entry for entry in db if entry["symbol"] == idx]
    if len(data) > 0:
        return data[0]
    else:
        LOG.warning(f"WARNING: Reference {idx} not found")
        return {}
def read_pandas_data(file_name):
    if not file_name.exists():
        LOG.error(f"ERROR: data not found for {file_name}")
        return None
    return pd.read_csv(file_name,
                       parse_dates=['date'],
                       index_col='date',
                       date_parser=dateparse)
def clean_pandas_data(dat):
    """Receives a dictionary of data, transforms the dict into a pandas
    DataFrame, and cleans the column names"""
    try:
        data = pd.DataFrame.from_dict(dat, orient="index")
        # Apply clean names to columns and index
        column_names = clean_enumeration(data.columns.tolist())
        data.columns = column_names
        data.index.name = 'date'
        data.sort_index(axis=0, inplace=True, ascending=True)  # Sort by date
    except Exception as err:
        LOG.error(f"Error cleaning dataset: {err}")
        data = None
    return data
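# A minimal sketch of the dict shape clean_pandas_data expects: date-keyed records, as
# returned by the Alpha Vantage time-series endpoints. The "1. open"-style column names
# follow the Vantage convention; clean_enumeration is assumed to strip the leading
# enumeration (e.g. "1. open" -> "open"). Wrapped in a helper so it does not run at
# import time.
def _demo_clean_pandas_data_input():
    dat = {
        "2024-01-02": {"1. open": "187.15", "2. high": "188.44",
                       "3. low": "183.89", "4. close": "185.64"},
        "2024-01-03": {"1. open": "184.22", "2. high": "185.88",
                       "3. low": "183.43", "4. close": "184.25"},
    }
    print(clean_pandas_data(dat))  # DataFrame indexed by date with cleaned column names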
def manage_vantage_errors(response, symbol):
    if "Error Message" in response.keys():
        LOG.error(
            f"ERROR: Not possible to retrieve {symbol}. Msg: {response['Error Message']}"
        )
    elif "Note" in response.keys():
        if response["Note"][:111] == ('Thank you for using Alpha Vantage! Our standard API call frequency '
                                      'is 5 calls per minute and 500 calls per day.'):
            LOG.info(
                f"Retrieving {symbol}:{get_tabs(symbol, prev=12)}Max frequency reached! Waiting..."
            )
            return "longWait"
    return None
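# Hypothetical payloads showing the two Alpha Vantage response shapes this helper
# distinguishes (mirroring the checks above): a hard "Error Message" vs. a rate-limit
# "Note". Only the keys matter here; the values are abbreviated stand-ins.
_EXAMPLE_ERROR_RESP = {"Error Message": "Invalid API call. (abbreviated example)"}
_EXAMPLE_NOTE_RESP = {"Note": "Thank you for using Alpha Vantage! Our standard API call "
                              "frequency is 5 calls per minute and 500 calls per day. "
                              "(abbreviated example)"}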
def get_optimizer(params: Iterable, args: argparse.Namespace) -> Optional[optim.Optimizer]:
    name = args.optim_type.lower()
    if name == 'sgd':
        LOG.info(
            f"SGD Optimizer <lr={args.lr}, momentum={args.momentum}, nesterov=True>"
        )
        return optim.SGD(params=params, lr=args.lr, momentum=args.momentum, nesterov=True)
    elif name == 'adam':
        LOG.info(
            f"Adam Optimizer <lr={args.lr}, betas={args.betas}, eps={args.eps}>"
        )
        return optim.Adam(params=params, lr=args.lr, betas=args.betas, eps=args.eps)
    elif name == 'adamw':
        LOG.info(
            f"AdamW Optimizer <lr={args.lr}, betas={args.betas}, eps={args.eps}, weight_decay={args.weight_decay}>"
        )
        return optim.AdamW(params=params, lr=args.lr, betas=args.betas, eps=args.eps,
                           weight_decay=args.weight_decay)
    else:
        LOG.error(f"Unsupported optimizer: [bold red]{name}[/].")
        raise ValueError(f"Unsupported optimizer: {name}.")
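# A minimal usage sketch: the Namespace fields mirror the attributes get_optimizer reads;
# the concrete hyperparameter values and the stand-in model are hypothetical.
def _demo_get_optimizer():
    import argparse
    import torch.nn as nn

    args = argparse.Namespace(optim_type='adamw', lr=1e-3, momentum=0.9,
                              betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2)
    model = nn.Linear(16, 50)  # stand-in model
    return get_optimizer(model.parameters(), args=args)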
def main(args):
    CONSOLE.rule("Sample CNN Main Script")
    CONSOLE.print(args)
    # determine mode
    LOG.info(f"Mode selected: {args.mode}")
    if args.mode == 'train':
        train_on_model(args)
    elif args.mode == 'test':
        test_on_model(args)
    elif args.mode == 'eval':
        eval_on_model(args)
    else:
        raise ValueError(f"Unsupported mode: {args.mode}")
    CONSOLE.rule("Task finished")
    return
def save_pandas_data(file_name, dat, old_data=None, verbose=VERBOSE):
    try:
        data = clean_pandas_data(dat)
        if old_data is not None:
            try:
                # Avoid the last index as it may contain an incomplete week or month
                last_dt = old_data.index[-2]
                idx = data.index.get_loc(last_dt.strftime("%Y-%m-%d"))
                updated_data = pd.concat((old_data.iloc[:-2, :], data.iloc[idx:, :]), axis=0)
                updated_data.reset_index().to_csv(file_name, index=False, compression="infer")  # Update
            except KeyError as err:
                LOG.error(f"Error updating the data: {err}")
        else:
            data.reset_index().to_csv(file_name, index=False, compression="infer")  # Save
        if verbose > 1:
            symbol = file_name.parent.name
            LOG.info(
                f"Saved {symbol} data:{get_tabs(symbol, prev=12)}[{file_name.stem}] OK"
            )
    except Exception as err:
        # format_tb returns the traceback as strings (print_tb would log None here)
        LOG.error(f"ERROR saving data:\t\t{file_name.parent.name + file_name.stem} "
                  f"{err!r} {''.join(traceback.format_tb(err.__traceback__))}")
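# A toy, self-contained illustration (hypothetical frames) of the update path in
# save_pandas_data: the last two stored rows are dropped and replaced by the fresh rows
# from the second-to-last stored date onward, so a possibly incomplete final week or
# month is overwritten. Wrapped in a helper so it does not run at import time.
def _demo_overlap_merge():
    import pandas as pd

    old = pd.DataFrame({'close': [1.0, 2.0, 3.0]},
                       index=pd.to_datetime(['2024-01-01', '2024-01-02', '2024-01-03']))
    new = pd.DataFrame({'close': [2.5, 3.5, 4.0]},
                       index=pd.to_datetime(['2024-01-02', '2024-01-03', '2024-01-04']))
    last_dt = old.index[-2]           # second-to-last stored date: 2024-01-02
    idx = new.index.get_loc(last_dt)  # its position in the fresh data
    merged = pd.concat((old.iloc[:-2, :], new.iloc[idx:, :]), axis=0)
    print(merged)  # 2024-01-01 kept from old; 2024-01-02 onward taken from new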
def MTT_statistics(args):
    from src.data.dataset import MTTDataset
    from rich.progress import Progress, BarColumn, TimeRemainingColumn, TimeElapsedColumn, TextColumn

    # initialize statistics
    dataset = MTTDataset(path=args.p_data, split=args.split)
    total_segments = len(dataset)
    n_samples_per_segment = args.n_samples
    sum_x, sum_x2 = 0, 0
    n_samples_total = total_segments * n_samples_per_segment
    # iterate over the dataset
    with Progress("[progress.description]{task.description}",
                  BarColumn(),
                  "[progress.percentage]{task.percentage:>3.0f}%",
                  TimeRemainingColumn(),
                  TextColumn("/"),
                  TimeElapsedColumn(),
                  "{task.completed} of {task.total} steps",
                  expand=False, console=CONSOLE, refresh_per_second=5) as progress:
        task = progress.add_task(description='[Scanning Dataset] ', total=total_segments)
        for i in range(total_segments):
            sample, label = dataset[i]
            sum_x += np.sum(sample)
            sum_x2 += np.sum(sample**2)
            progress.update(task, advance=1)
    LOG.info("Calculating final results...")
    mean = sum_x / n_samples_total
    std = np.sqrt((sum_x2 - sum_x**2 / n_samples_total) / (n_samples_total - 1))
    LOG.info(f"Mean: {mean}\nStddev: {std}")
    return
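# A quick numerical check (synthetic data) that the one-pass formula used above,
# std = sqrt((sum(x^2) - sum(x)^2 / N) / (N - 1)), matches numpy's sample standard
# deviation. Wrapped in a helper so it does not run at import time.
def _check_one_pass_std():
    import numpy as np

    x = np.random.default_rng(0).normal(size=10_000)
    n = x.size
    one_pass = np.sqrt((np.sum(x**2) - np.sum(x)**2 / n) / (n - 1))
    assert np.isclose(one_pass, np.std(x, ddof=1))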
def process_vantage_data(data):
    """Receives the data as a dictionary of info + values and returns the two
    independent dictionaries"""
    metadata = data.get("Meta Data", None)
    if metadata:
        try:
            info = clean_enumeration(metadata)
        except Exception:
            LOG.error(f"ERROR cleaning info: {metadata}")
            info = metadata
    else:
        info = {}
    # e.g. 'Time Series (Daily)' or 'Time Series FX (Weekly)'
    data_key = [k for k in data.keys() if k != "Meta Data"][0]
    dat = data[data_key]
    return info, dat
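# A hypothetical Alpha Vantage payload illustrating the split performed by
# process_vantage_data: "Meta Data" becomes `info` (after key cleaning) and the single
# remaining time-series key becomes `dat`. Field names follow the Vantage convention;
# the values are made up. Wrapped in a helper so it does not run at import time.
def _demo_process_vantage_data():
    payload = {
        "Meta Data": {"1. Information": "Daily Prices", "2. Symbol": "IBM"},
        "Time Series (Daily)": {
            "2024-01-03": {"1. open": "184.22", "4. close": "184.25"},
        },
    }
    info, dat = process_vantage_data(payload)
    print(info)  # cleaned metadata
    print(dat)   # the date-keyed series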
def _process_audio_files(worker_id: int,
                         tasks: pd.DataFrame,
                         p_out: PosixPath,
                         p_raw: PosixPath,
                         n_samples: int = 59049,
                         sample_rate: int = 22050,
                         topk: int = 50,
                         file_pattern: str = 'clip-{}-seg-{}-of-{}') -> None:
    n_tasks = tasks.shape[0]
    t_start = time.time()
    n_parts = max(1, n_tasks // 10)  # report roughly every 10%; guard against n_tasks < 10
    idx = 0
    LOG.info(f"[Worker {worker_id:02d}]: Received {n_tasks} tasks.")
    for i, t in tasks.iterrows():
        # find output dir
        split = t.split
        out_dir = p_out.joinpath(split)
        # process audio file
        try:
            segments = _segment_audio(_load_audio(p_raw.joinpath(t.mp3_path), sample_rate=sample_rate),
                                      n_samples=n_samples, center=False)
            loaded = True
        except (RuntimeError, EOFError):
            LOG.warning(f"[Worker {worker_id:02d}]: Failed to load audio: {t.mp3_path}. Ignored.")
            loaded = False
        # save label and segments to npy files
        if loaded:
            labels = t[t.index.tolist()[:topk]].values.astype(bool)
            n_segments = len(segments)
            for j, seg in enumerate(segments):
                np.savez_compressed(out_dir.joinpath(file_pattern.format(t.clip_id, j + 1, n_segments)).as_posix(),
                                    data=seg, labels=labels)
        # report progress
        idx += 1
        if idx == n_tasks:
            LOG.info(f"[Worker {worker_id:02d}]: Job finished. Quit. "
                     f"(time usage: {(time.time() - t_start) / 60:.02f} min)")
        elif idx % n_parts == 0:
            LOG.info(f"[Worker {worker_id:02d}]: {idx // n_parts * 10}% tasks done. "
                     f"(time usage: {(time.time() - t_start) / 60:.02f} min)")
    return
async def query_data(symbol, category=None, api="vantage", verbose=VERBOSE, **kwargs):
    if category is None:
        raise ValueError("Please provide a valid category in the parameters")
    # Get semaphore
    semaphore_controller.get_semaphore(api)
    if verbose > 2:
        LOG.info("Successfully acquired the semaphore")
    if api == "vantage":
        url, params = alpha_vantage_query(symbol, category,
                                          key=KEYS_SET["alpha_vantage"], **kwargs)
        LOG.info(f"Retrieving {symbol}:{get_tabs(symbol, prev=12)}From '{api}' API")
    else:
        # raise instead of falling through: url/params would be undefined below
        LOG.error(f"Unsupported api: {api}")
        raise ValueError(f"Unsupported api: {api}")
    counter = 0
    while counter <= QUERY_RETRY_LIMIT:
        async with aiohttp.ClientSession() as session:
            async with session.get(url, params=params, headers=HEADERS) as resp:
                data = await resp.json()
                if api == "vantage":
                    if manage_vantage_errors(data, symbol) == "longWait":
                        counter += 1
                        await asyncio.sleep(VANTAGE_WAIT)
                    else:
                        break
    await asyncio.sleep(MIN_SEM_WAIT)
    if verbose > 2:
        LOG.info("Releasing Semaphore")
    # Release semaphore
    semaphore_controller.release_semaphore(api)
    return data
async def read_info_file(info_file, check=True, verbose=VERBOSE):
    if not info_file:
        return {}
    if info_file.exists():
        async with aiofiles.open(info_file, "r") as info:
            data = await info.read()
        if verbose > 1:
            LOG.info(f"Info file read:{get_tabs('', prev=15)}{info_file}")
        return json.loads(data)
    else:
        if check:
            LOG.error(f"ERROR: No info found at {info_file}")
        if verbose > 1:
            LOG.warning(f"Info file: {info_file}\tDOES NOT EXIST!")
        return {}
async def update_stock(symbol, category="daily", max_gap=0, api="vantage", verbose=VERBOSE):
    folder_name, file_name = build_path_and_file(symbol, category)
    info_file = build_info_file(folder_name, category)
    info = None
    try:
        if folder_name.exists() and file_name.exists():
            # Verify how much must be updated
            data_stored = read_pandas_data(file_name)
            first_date = data_stored.index[0]
            last_date = data_stored.index[-1]
            if delta_surpassed(last_date, max_gap, category):
                LOG.info(f"Updating {symbol} data...")
                # Retrieve only the last range (alpha_vantage 100pts)
                data = await query_data(symbol, category=category, api="vantage",
                                        outputsize="compact")
                if data in [None, {}]:
                    LOG.warning(f"No data received for {symbol}")
                    return
                info, dat = process_vantage_data(data)
                info = add_first_ts(info, first_date)
                save_pandas_data(file_name, dat, old_data=data_stored, verbose=verbose)
            else:
                if verbose > 1:
                    LOG.info(f"Updating {symbol}:{get_tabs(symbol, prev=10)}"
                             f"Ignored. Data {category} < {max_gap}d old")
                return
        else:
            # Download and save new data
            if verbose > 1:
                LOG.info(f"Updating {symbol} ...")
            data = await query_data(symbol, category=category, api=api)
            if data in [None, {}]:
                LOG.warning(f"No data received for {symbol}")
                return
            info, dat = process_vantage_data(data)
            save_pandas_data(file_name, dat, verbose=verbose)
        # Save/Update info
        if info:
            await update_stock_info(info_file, info)
        if verbose > 1:
            LOG.info(f"Updating {symbol}:{get_tabs(symbol, prev=10)}Finished")
    except Exception as err:
        LOG.error(f"Updating {symbol}:{get_tabs(symbol, prev=10)}ERROR: {err!r} "
                  f"{''.join(traceback.format_tb(err.__traceback__))}")
def test_on_model(args):
    device = args.device
    if device == 'cpu':
        raise NotImplementedError("CPU inference is not implemented.")
    device = torch.device(args.device)
    torch.cuda.set_device(device)

    # build model
    model = build_model(args)
    model.to(device)

    # output dir
    p_out = Path(args.p_out).joinpath(f"{model.name}-{args.tensorboard_exp_name}")
    if not p_out.exists():
        p_out.mkdir(exist_ok=True, parents=True)

    # dataset & loader
    test_dataset = MTTDataset(path=args.p_data, split='test')
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.n_workers,
                             pin_memory=True,
                             drop_last=False)  # not dropping last in testing
    test_steps = test_dataset.calc_steps(args.batch_size,
                                         drop_last=False)  # not dropping last in testing
    LOG.info(f"Total testing steps: {test_steps}")
    LOG.info(f"Testing data size: {len(test_dataset)}")

    # create loss
    loss_fn = get_loss(args.loss)
    # create metric
    metric = AUCMetric()

    # load checkpoint OR init state_dict
    if args.checkpoint is not None:
        state_dict = load_ckpt(args.checkpoint,
                               reset_epoch=args.ckpt_epoch,
                               no_scheduler=args.ckpt_no_scheduler,
                               no_optimizer=args.ckpt_no_optimizer,
                               no_loss_fn=args.ckpt_no_loss_fn,
                               map_values=args.ckpt_map_values)
        model_dict = {'model': model} if 'model' in state_dict else None
        apply_state_dict(state_dict, model=model_dict)
        best_val_loss = state_dict['val_loss']
        epoch = state_dict['epoch']
        global_i = state_dict['global_i']
        LOG.info(f"Checkpoint loaded. Epoch trained {epoch}, global_i {global_i}, "
                 f"best val {best_val_loss:.6f}")
    else:
        raise AssertionError("Pre-trained checkpoint must be provided.")

    # summary writer
    writer = SummaryWriter(log_dir=p_out.as_posix(), filename_suffix='-test')

    # start testing
    model.eval()
    sigmoid = Sigmoid().to(device)
    status_col = TextColumn("")
    running_loss = 0
    if args.data_normalization:
        fetcher = DataPrefetcher(test_loader, mean=MTT_MEAN, std=MTT_STD, device=device)
    else:
        fetcher = DataPrefetcher(test_loader, mean=None, std=None, device=device)
    samples, targets = fetcher.next()
    with Progress("[progress.description]{task.description}",
                  "[{task.completed}/{task.total}]",
                  BarColumn(),
                  "[progress.percentage]{task.percentage:>3.0f}%",
                  TimeRemainingColumn(),
                  TextColumn("/"),
                  TimeElapsedColumn(),
                  status_col,
                  expand=False, console=CONSOLE, refresh_per_second=5) as progress:
        task = progress.add_task(description='[Test]', total=test_steps)
        i = 0  # counter
        t_start = time.time()
        with torch.no_grad():
            while samples is not None:
                # forward model
                logits = model(samples)
                out = sigmoid(logits)
                test_loss = loss_fn(logits, targets)
                # collect running loss
                running_loss += test_loss.item()
                i += 1
                writer.add_scalar('Test/Loss', running_loss / i, i)
                # auc metric
                metric.step(targets.cpu().numpy(), out.cpu().numpy())
                # pre-fetch next samples
                samples, targets = fetcher.next()
                if not progress.finished:
                    status_col.text_format = f"Test loss: {running_loss/i:.06f}"
                    progress.update(task, advance=1)
    auc_tag, auc_sample, ap_tag, ap_sample = metric.auc_ap_score
    LOG.info(f"Testing speed: {(time.time() - t_start)/i:.4f}s/it, "
             f"auc_tag: {auc_tag:.04f}, "
             f"auc_sample: {auc_sample:.04f}, "
             f"ap_tag: {ap_tag:.04f}, "
             f"ap_sample: {ap_sample:.04f}")
    writer.close()
    return
def prepare_MTT_dataset(args):
    CONSOLE.rule("Pre-processing MTT Annotations and Data for Machine Learning")
    # --- get dirs ---
    # create out dir if it does not exist
    p_out = Path(args.p_out).absolute()
    while True:
        if p_out.exists():
            res = CONSOLE.input(f"Output folder exists ({p_out.as_posix()})! Do you want to remove it first? "
                                f"(You can also clean it manually now and hit enter key to retry) [y/n]: ")
            if res.lower() in ['y', 'yes']:
                # delete target folder
                shutil.rmtree(p_out)
                # create a new one
                p_out.mkdir()
                LOG.info("Target folder removed, and new empty folder created.")
                break
            elif res.lower() in ['n', 'no']:
                LOG.error(f"Output folder exists! Creating folder failed. Target: {p_out.as_posix()} exists.")
                raise FileExistsError(f"Output folder exists! Creating folder failed. Target: {p_out.as_posix()} exists.")
            else:
                continue
        else:
            p_out.mkdir()
            LOG.info(f"Target folder ({p_out.as_posix()}) created.")
            break
    # train/val/test dirs
    p_out.joinpath('train').mkdir()
    p_out.joinpath('val').mkdir()
    p_out.joinpath('test').mkdir()

    # check raw data (glob narrowed to the hex digits named in the message)
    p_raw = Path(args.p_raw).absolute()
    assert len(list(p_raw.glob('[0-9a-f]'))) == 16, \
        "MTT raw data should have 16 directories, 0-9 and a-f."

    # --- parse and process annotations ---
    annotations = process_MTT_annotations(p_anno=args.p_anno,
                                          p_info=args.p_info,
                                          delimiter=args.delimiter,
                                          n_top=args.n_topk)
    # save processed annotations
    annotations.to_csv(Path(args.p_anno).parent.joinpath(f'annotations_top{args.n_topk}.csv').as_posix(),
                       index=False)
    # save topk labels
    with open(Path(args.p_anno).parent.joinpath('labels.txt').as_posix(), 'w') as f:
        f.write(','.join(annotations.columns.tolist()[:args.n_topk]))

    CONSOLE.rule("Audio Preprocessing")
    LOG.info("MTT annotations processed. Now segmenting audio based on the annotations for machine learning...")
    # --- process audio files based on annotations ---
    avg = annotations.shape[0] // args.n_worker
    processes = [Process(target=_process_audio_files,
                         args=(i, annotations.iloc[i*avg:(i+1)*avg], p_out, p_raw,
                               args.n_samples, args.sr, args.n_topk))
                 if i != args.n_worker - 1 else
                 Process(target=_process_audio_files,
                         args=(i, annotations.iloc[i*avg:], p_out, p_raw,
                               args.n_samples, args.sr, args.n_topk))
                 for i in range(args.n_worker)]
    LOG.info(f"{args.n_worker} workers created.")
    # start jobs
    for p in processes:
        p.start()
    # wait for jobs to finish
    for p in processes:
        p.join()
    CONSOLE.rule('MTT Dataset Preparation Done')
    return
def process_MTT_annotations(p_anno: str = 'annotations_final.csv',
                            p_info: str = 'clip_info_final.csv',
                            delimiter: str = '\t',
                            n_top: int = 50) -> pd.DataFrame:
    """
    Reads the annotation file, keeps the top N tags, and splits the data samples.

    The result has 55 columns (top50_tags + [clip_id, mp3_path, split, title, artist]):
        ['guitar', 'classical', 'slow', 'techno', 'strings', 'drums', 'electronic',
         'rock', 'fast', 'piano', 'ambient', 'beat', 'violin', 'vocal', 'synth',
         'female', 'indian', 'opera', 'male', 'singing', 'vocals', 'no vocals',
         'harpsichord', 'loud', 'quiet', 'flute', 'woman', 'male vocal', 'no vocal',
         'pop', 'soft', 'sitar', 'solo', 'man', 'classic', 'choir', 'voice', 'new age',
         'dance', 'female vocal', 'male voice', 'beats', 'harp', 'cello', 'no voice',
         'weird', 'country', 'female voice', 'metal', 'choral',
         'clip_id', 'mp3_path', 'split', 'title', 'artist']

    NOTE: This excludes audio clips that have only zero-tags. Therefore, the split
    sizes will be 15250 / 1529 / 4332 (training / validation / test).

    :param p_anno: Path to the annotation CSV file
    :param p_info: Path to the song info CSV file
    :param delimiter: CSV delimiter
    :param n_top: Number of the most popular tags to keep
    :return: A DataFrame containing information about the audio clips
        Schema:
        <tags>: 0 or 1
        clip_id: clip_id of the original dataset
        mp3_path: path to the mp3 audio file
        split: dataset split (train / val / test). The split is determined by the
            clip's directory (0, 1, ..., f): the first 12 directories (0-b) are used
            for training, one (c) for validation, and three (d-f) for test.
        title: title of the song the clip is taken from
        artist: artist of the song the clip is taken from
    """
    def split_by_directory(mp3_path: str) -> str:
        directory = mp3_path.split('/')[0]
        part = int(directory, 16)
        if part in range(12):
            return 'train'
        elif part == 12:
            return 'val'
        elif part in range(13, 16):
            return 'test'

    LOG.info(f"Starting pre-processing of MTT annotations, "
             f"keeping the {n_top} top tags, "
             f"and merging song info.")
    # read csv annotations
    df_anno = pd.read_csv(p_anno, delimiter=delimiter)
    df_info = pd.read_csv(p_info, delimiter=delimiter)
    LOG.info(f"Loaded annotations from [bold]{p_anno}[/], "
             f"loaded song info from [bold]{p_info}[/], "
             f"which contains {df_anno.shape[0]} songs.")
    # get the top n_top tags
    top50 = df_anno.drop(['clip_id', 'mp3_path'], axis=1)\
                   .sum()\
                   .sort_values(ascending=False)[:n_top]\
                   .index\
                   .tolist()
    LOG.info(f"Top {n_top} tags:\n{top50}")
    # remove low-frequency tags
    df_anno = df_anno[top50 + ['clip_id', 'mp3_path']]
    # remove songs that have 0 tags
    df_anno = df_anno[df_anno[top50].sum(axis=1) != 0]
    # create train/val/test splits
    df_anno['split'] = df_anno['mp3_path'].transform(split_by_directory)
    # show splits
    for split in ['train', 'val', 'test']:
        LOG.info(f"{split} set size (#audio): {sum(df_anno['split'] == split)}.")
    # merge annotations and song info
    df_merge = pd.merge(df_anno, df_info[['clip_id', 'title', 'artist']], on='clip_id')
    LOG.info(f"Final quantity of songs: {df_merge.shape[0]}\n"
             f"Final columns ({df_merge.columns.size}) in the DataFrame:\n"
             f"{df_merge.columns.tolist()}")
    return df_merge
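# A quick worked example of the directory-based split rule in split_by_directory:
# MTT directory names are single hex digits, so int(name, 16) maps '0'-'b' -> 0-11
# (train), 'c' -> 12 (val), and 'd'-'f' -> 13-15 (test). Wrapped in a helper so it
# does not run at import time.
def _demo_split_by_directory():
    for d in ['0', '7', 'b', 'c', 'd', 'f']:
        part = int(d, 16)
        split = 'train' if part < 12 else ('val' if part == 12 else 'test')
        print(d, '->', split)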
def eval_on_model(args):
    device = args.device
    if device == 'cpu':
        raise NotImplementedError("CPU inference is not implemented.")
    device = torch.device(args.device)
    torch.cuda.set_device(device)

    # build model
    model = build_model(args)
    model.to(device)

    # output dir
    p_out = Path(args.p_out).joinpath(f"{model.name}-{args.tensorboard_exp_name}")
    if not p_out.exists():
        p_out.mkdir(exist_ok=True, parents=True)

    # dataset & loader
    annotation = pd.read_csv(args.annotation_file)
    query = annotation[annotation.mp3_path.str.match('/'.join(args.audio_file.split('/')[-2:]))]
    assert query.shape[0] != 0, f"Cannot find the audio file: {args.audio_file}"

    # split audio info and segment audio
    threshold = args.eval_threshold
    song_info = query[query.columns.values[50:]]
    tags = query.columns.values[:50]
    labels = query[tags].values[0]
    label_names = tags[labels.astype(bool)]
    segments = _segment_audio(_load_audio(args.audio_file, sample_rate=22050), n_samples=59049)
    LOG.info(f"Song info: {song_info}")
    LOG.info(f"Number of segments: {len(segments)}")
    LOG.info(f"Ground truth tags: {label_names}")
    LOG.info(f"Positive tag threshold: {threshold}")

    # create loss
    loss_fn = get_loss(args.loss)

    # load checkpoint OR init state_dict
    if args.checkpoint is not None:
        state_dict = load_ckpt(args.checkpoint,
                               reset_epoch=args.ckpt_epoch,
                               no_scheduler=args.ckpt_no_scheduler,
                               no_optimizer=args.ckpt_no_optimizer,
                               no_loss_fn=args.ckpt_no_loss_fn,
                               map_values=args.ckpt_map_values)
        model_dict = {'model': model} if 'model' in state_dict else None
        apply_state_dict(state_dict, model=model_dict)
        best_val_loss = state_dict['val_loss']
        epoch = state_dict['epoch']
        global_i = state_dict['global_i']
        LOG.info(f"Checkpoint loaded. Epoch trained {epoch}, global_i {global_i}, "
                 f"best val {best_val_loss:.6f}")
    else:
        raise AssertionError("Pre-trained checkpoint must be provided.")

    # start testing
    model.eval()
    sigmoid = Sigmoid().to(device)
    t_start = time.time()
    # concatenate segments; replicate the labels once per segment
    # (instead of a hard-coded 10, which breaks for other segment counts)
    n_segments = len(segments)
    segments = torch.from_numpy(np.concatenate([seg.reshape(1, 1, -1) for seg in segments])
                                ).to(torch.float32).cuda(device=device)
    targets = torch.from_numpy(np.concatenate([labels.reshape(1, -1)] * n_segments)
                               ).to(torch.float32).cuda(device=device)
    # forward pass
    with torch.no_grad():
        logits = model(segments)
        out = sigmoid(logits)
        loss = loss_fn(logits, targets)
    out = out.cpu().numpy()
    out[out > threshold] = 1
    out[out <= threshold] = 0
    out = np.sum(out, axis=0)
    res = pd.DataFrame(data={'tags': tags, 'freq': out})
    res = res[res.freq != 0].sort_values(by='freq', ascending=False)
    CONSOLE.print(res)
    LOG.info(f"Testing speed: {time.time() - t_start:.4f}s, "
             f"loss: {loss.item()}")
    return
def __init__(self,
             block: Type[Bottleneck],
             layers: List[int],
             channels: List[int] = (64, 128, 256, 512),
             in_channels: int = 1,
             expansion: int = 2,
             n_class: int = 50,
             zero_init_residual: bool = False,
             name: Optional[str] = None,
             robust: bool = False):
    super(ExpSampleCNN, self).__init__()
    assert len(layers) == len(channels), "layers and channels length mismatch!"
    self.n_modules = len(layers)
    self.n_layers = sum(layers) * 3 + 2
    self._norm_layer = nn.BatchNorm1d
    self.in_planes = 64

    # input head
    self.conv1 = nn.Conv1d(in_channels, self.in_planes,
                           kernel_size=3, stride=3, padding=1, bias=False)
    self.bn1 = self._norm_layer(self.in_planes)
    self.relu = nn.ReLU()
    self.maxpool = nn.MaxPool1d(kernel_size=3, stride=3, padding=1)

    # layers
    for i in range(len(layers)):
        if i == 0:
            setattr(self, f'layer{i+1}',
                    self._make_layer(block, channels[i], layers[i],
                                     stride=1, expansion=expansion))
        else:
            setattr(self, f'layer{i+1}',
                    self._make_layer(block, channels[i], layers[i],
                                     stride=3, expansion=expansion))

    # fc
    self.avgpool = nn.AdaptiveAvgPool1d(1)
    self.fc = nn.Linear(channels[-1] * expansion, n_class)

    # init params
    for m in self.modules():
        if isinstance(m, nn.Conv1d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm1d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        # Zero-initialize the last BN in each residual branch so that the residual
        # branch starts with zeros, and each residual block behaves like an identity
        if zero_init_residual and isinstance(m, Bottleneck):
            nn.init.constant_(m.bn3.weight, 0)

    # name
    self.name = 'ExpSampleCNN' if name is None else name
    if robust:
        LOG.info(f"{self.name}\n"
                 f"model modules: {self.n_modules}\n"
                 f"model layers: {self.n_layers}\n"
                 f"model params: {self.n_params}")
    return
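# A sanity check (hypothetical configuration) of the layer arithmetic above: each
# Bottleneck block contributes 3 convolutions, plus the input conv and the final fc,
# hence n_layers = 3 * sum(layers) + 2.
def _check_layer_count():
    layers = [2, 2, 2, 2]  # hypothetical block counts per module
    assert sum(layers) * 3 + 2 == 26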
def train_on_model(args):
    if args.device == 'cpu':
        raise NotImplementedError("CPU training is not implemented.")
    device = torch.device(args.device)
    torch.cuda.set_device(device)

    # build model
    model = build_model(args)
    model.to(device)

    # output dir
    p_out = Path(args.p_out).joinpath(f"{model.name}-{args.tensorboard_exp_name}")
    if not p_out.exists():
        p_out.mkdir(exist_ok=True, parents=True)

    # dataset & loader
    train_dataset = MTTDataset(path=args.p_data, split='train')
    val_dataset = MTTDataset(path=args.p_data, split='val')
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.n_workers,
                              pin_memory=True,
                              drop_last=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.n_workers,
                            pin_memory=True,
                            drop_last=True)
    train_steps = train_dataset.calc_steps(args.batch_size)
    val_steps = val_dataset.calc_steps(args.batch_size)
    if args.data_normalization:
        normalize = (MTT_MEAN, MTT_STD)
        LOG.info("Data normalization [bold cyan]on[/]")
    else:
        normalize = None
    LOG.info(f"Total training steps: {train_steps}")
    LOG.info(f"Total validation steps: {val_steps}")
    LOG.info(f"Training data size: {len(train_dataset)}")
    LOG.info(f"Validation data size: {len(val_dataset)}")

    # create optimizer
    optim = get_optimizer(model.parameters(), args=args)
    # create loss
    loss_fn = get_loss(args.loss)
    # create schedulers
    scheduler_plateau = ReduceLROnPlateau(optim,
                                          factor=args.lr_decay_plateau,
                                          patience=args.plateau_patience,
                                          min_lr=args.min_lr,
                                          verbose=True,
                                          prefix="[Scheduler Plateau] ",
                                          logger=LOG)
    scheduler_es = EarlyStopping(patience=args.early_stop_patience,
                                 min_delta=args.early_stop_delta,
                                 verbose=True,
                                 prefix="[Scheduler Early Stop] ",
                                 logger=LOG)

    # load checkpoint OR init state_dict
    if args.checkpoint is not None:
        state_dict = load_ckpt(args.checkpoint,
                               reset_epoch=args.ckpt_epoch,
                               no_scheduler=args.ckpt_no_scheduler,
                               no_optimizer=args.ckpt_no_optimizer,
                               no_loss_fn=args.ckpt_no_loss_fn,
                               map_values=args.ckpt_map_values)
        model_dict = {'model': model} if 'model' in state_dict else None
        optim_dict = {'optim': optim} if 'optim' in state_dict else None
        loss_fn_dict = {'loss_fn': loss_fn} if 'loss_fn' in state_dict else None
        scheduler_dict = {'scheduler_plateau': scheduler_plateau} \
            if 'scheduler_plateau' in state_dict else None
        apply_state_dict(state_dict,
                         model=model_dict,
                         optimizer=optim_dict,
                         loss_fn=loss_fn_dict,
                         scheduler=scheduler_dict)
        best_val_loss = state_dict['val_loss']
        epoch = state_dict['epoch']
        global_i = state_dict['global_i']
        LOG.info(f"Checkpoint loaded. Epoch trained {epoch}, global_i {global_i}, "
                 f"best val {best_val_loss:.6f}")
    else:
        # fresh training
        best_val_loss = 9999
        epoch = 0
        global_i = 0

    # tensorboard
    purge_step = None if global_i == 0 else global_i
    writer = SummaryWriter(log_dir=VAR.log
                           .joinpath('tensorboard')
                           .joinpath(f"{model.name}-{args.tensorboard_exp_name}")
                           .as_posix(),
                           purge_step=purge_step,
                           filename_suffix='-train')

    # train model for epochs
    assert epoch < args.max_epoch, \
        "Initial epoch value must be smaller than max_epoch in order to train the model"
    for i in range(epoch, args.max_epoch):
        # train
        init_lr = optim.param_groups[0]['lr']
        train_loss, global_i = train_one_epoch(model, optim, loss_fn, train_loader,
                                               epoch + 1, train_steps, device, writer,
                                               global_i,
                                               writer_interval=args.tensorboard_interval,
                                               normalize=normalize)
        # validate
        val_loss = evaluate(model, loss_fn, val_loader, epoch + 1, val_steps, device,
                            normalize=normalize)
        writer.add_scalar('Loss/Val', val_loss, global_i)
        epoch += 1

        # update schedulers
        scheduler_plateau.step(val_loss)
        scheduler_es.step(val_loss)

        # snapshot of the current training state, shared by all save paths below
        ckpt = {
            'model': model.state_dict(),
            'optim': optim.state_dict(),
            'loss_fn': loss_fn.state_dict(),
            'scheduler_plateau': scheduler_plateau.state_dict(),
            'scheduler_es': scheduler_es.state_dict(),
            'epoch': epoch,
            'loss': train_loss,
            'val_loss': val_loss,
            'p_out': p_out,
            'global_i': global_i
        }
        # save checkpoint when the scheduler reduced the learning rate
        if optim.param_groups[0]['lr'] != init_lr:
            LOG.info(f"Saving [red bold]checkpoint[/] at epoch {epoch}, "
                     f"model saved to {p_out.as_posix()}")
            torch.save(ckpt, p_out.joinpath(f'ckpt@epoch-{epoch:03d}-loss-{val_loss:.6f}.tar').as_posix())
        # save the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            LOG.info(f"New [red bold]best[/] validation loss {val_loss:.6f}, "
                     f"model saved to {p_out.as_posix()}")
            torch.save(ckpt, p_out.joinpath(f'best@epoch-{epoch:03d}-loss-{val_loss:.6f}.tar').as_posix())
        # save the latest model
        else:
            torch.save(ckpt, p_out.joinpath('latest.tar').as_posix())

        # early stop, if enabled
        if scheduler_es.early_stop:
            break

        # load the optimal model when the lr changed, if enabled
        if optim.param_groups[0]['lr'] != init_lr and args.load_optimal_on_plateau:
            # save lr before restoring
            cur_lr = [param_group['lr'] for param_group in optim.param_groups]
            # restore last best model
            state_dict = find_optimal_model(p_out)
            apply_state_dict(state_dict,
                             model={'model': model},
                             optimizer={'optim': optim},
                             loss_fn=None,
                             scheduler=None)
            apply_lr(optim, cur_lr)
            # reset global_i and epoch
            global_i = state_dict['global_i']
            epoch = state_dict['epoch']
            LOG.info(f"Best model (val loss {state_dict['val_loss']}) applied. "
                     f"Rolled back to epoch {epoch}")
            # reset tensorboard writer
            writer.close()
            writer = SummaryWriter(log_dir=VAR.log
                                   .joinpath('tensorboard')
                                   .joinpath(f"{model.name}-{args.tensorboard_exp_name}")
                                   .as_posix(),
                                   purge_step=global_i,
                                   filename_suffix='-train')

    # close tensorboard
    writer.close()
    return