def validate(datastream, cpc, model, num_emotions):
    losses = []
    model.eval()
    # Stash and later restore states for non-leaky validation
    cpc.stash_state()
    model.stash_state()
    # reset to a fixed random seed for deterministic and comparable validation
    with FixedRandomState(42):
        for step, batch in enumerate(datastream):
            data, labels = batch["data"].to(device), batch["labels"]
            with torch.no_grad():
                features = cpc(data)
                pred = model(features)
                labels = resample_1d(labels, pred.shape[1])
                pred = pred.reshape(-1, num_emotions)
                labels = labels.reshape(-1)
                losses.append(F.cross_entropy(pred, labels.to(device)).item())
            if step >= FLAGS.valid_steps:
                break
    cpc.pop_state()
    model.pop_state()
    model.train()
    return np.array(losses).mean()
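# NOTE: `FixedRandomState` is imported from elsewhere in this repo and is not
# defined in this file. A minimal sketch of the assumed contract (seed all RNGs
# on entry, restore the previous RNG state on exit, so validation never perturbs
# the training RNG stream) is given below; the real helper may differ, e.g. it
# may also handle CUDA RNG state.
import contextlib
import random


@contextlib.contextmanager
def FixedRandomState(seed):
    # Save the current RNG states so they can be restored afterwards.
    py_state = random.getstate()
    np_state = np.random.get_state()
    torch_state = torch.get_rng_state()
    try:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        yield
    finally:
        random.setstate(py_state)
        np.random.set_state(np_state)
        torch.set_rng_state(torch_state)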
def train(unused_argv):
    set_seeds(FLAGS.seed)
    write_current_pid(FLAGS.expdir)
    # setup logging
    tb_logger = prepare_tb_logging()
    prepare_standard_logging("training")
    loss_dir = Path(f"{FLAGS.expdir}/losses")
    loss_dir.mkdir(exist_ok=True)
    train_losses_fh = open(loss_dir / "train.txt", "a", buffering=1)
    valid_losses_fh = open(loss_dir / "valid.txt", "a", buffering=1)

    if FLAGS.dry_run:
        setup_dry_run(FLAGS)
    if not FLAGS.model_out:
        FLAGS.model_out = FLAGS.expdir + "/model.pt"
    if not FLAGS.checkpoint_out:
        FLAGS.checkpoint_out = FLAGS.expdir + "/checkpoint.pt"
    if FLAGS.checkpoint_autoload and not FLAGS.checkpoint:
        FLAGS.checkpoint = get_checkpoint_to_start_from(FLAGS.checkpoint_out)
        logging.info(f"autosetting checkpoint: {FLAGS.checkpoint}")

    if FLAGS.cpc_path is not None:
        cpc = load_model(FLAGS.cpc_path).to(device)
        cpc.reset_state()
    else:
        cpc = NoCPC()
    cpc.eval()

    # write information about cpc into metadata
    with open(f"{FLAGS.expdir}/metadata.txt", "a") as fh:
        fh.write(f"sampling_rate_hz {cpc.data_class.SAMPLING_RATE_HZ}\n")
        fh.write(f"feat_dim {cpc.feat_dim}\n")

    # define training data
    parsed_train_dbl = parse_emotion_dbl(FLAGS.train_data)
    train_streams = [
        DblStream(
            DblSampler(parsed_train_dbl),
            EmotionIDSingleFileStream,
            FLAGS.window_size,
            emotion_set_path=FLAGS.emotion_set_path,
            audiostream_class=cpc.data_class,
        )
        for _ in range(FLAGS.batch_size)
    ]
    train_datastream = MultiStreamDataLoader(train_streams, device=device)

    # define validation data
    parsed_valid_dbl = parse_emotion_dbl(FLAGS.val_data)
    parsed_test_dbl = parse_emotion_dbl(FLAGS.test_data)
    val_streams = [
        DblStream(
            DblSampler(parsed_valid_dbl),
            EmotionIDSingleFileStream,
            FLAGS.window_size,
            emotion_set_path=FLAGS.emotion_set_path,
            audiostream_class=cpc.data_class,  # TODO ensure un-augmented stream
        )
        for _ in range(FLAGS.batch_size)
    ]
    valid_datastream = MultiStreamDataLoader(val_streams, device=device)

    if not FLAGS.val_every:
        FLAGS.val_every = max(100, FLAGS.steps // 50)
    if not FLAGS.save_every:
        FLAGS.save_every = FLAGS.val_every
    if not FLAGS.valid_steps:
        FLAGS.valid_steps = max(20, FLAGS.val_every // 100)
    valid_frames = FLAGS.batch_size * FLAGS.window_size * FLAGS.valid_steps

    feat_dim = cpc.feat_dim
    num_emotions = len(get_emotion_to_id_mapping(FLAGS.emotion_set_path))
    if FLAGS.model == "linear":
        model = LinearEmotionIDModel(feat_dim, num_emotions).to(device)
    elif FLAGS.model == "baseline":
        model = BaselineEmotionIDModel(feat_dim, num_emotions).to(device)
    elif FLAGS.model == "mlp2":
        model = MLPEmotionIDModel(
            feat_dim,
            num_emotions,
            no_layers=2,
            hidden_size=FLAGS.hidden_size,
            dropout_prob=FLAGS.dropout_prob,
            batch_norm_on=FLAGS.batch_norm,
        ).to(device)
    elif FLAGS.model == "mlp4":
        model = MLPEmotionIDModel(
            feat_dim,
            num_emotions,
            no_layers=4,
            hidden_size=FLAGS.hidden_size,
            dropout_prob=FLAGS.dropout_prob,
            batch_norm_on=FLAGS.batch_norm,
        ).to(device)
    elif FLAGS.model == "conv":
        model = ConvEmotionIDModel(
            feat_dim,
            num_emotions,
            no_layers=4,
            hidden_size=FLAGS.hidden_size,
            dropout_prob=FLAGS.dropout_prob,
        ).to(device)
    elif FLAGS.model == "rnn":
        model = RecurrentEmotionIDModel(
            feat_dim=feat_dim,
            num_emotions=num_emotions,
            bidirectional=False,
            hidden_size=FLAGS.hidden_size,
            dropout_prob=FLAGS.dropout_prob,
        ).to(device)
    elif FLAGS.model == "rnn_bi":
        model = RecurrentEmotionIDModel(
            feat_dim=feat_dim,
            num_emotions=num_emotions,
            bidirectional=True,
            hidden_size=FLAGS.hidden_size,
            dropout_prob=FLAGS.dropout_prob,
        ).to(device)
    elif FLAGS.model == "wavenet":
        model = WaveNetEmotionIDModel(feat_dim, num_emotions).to(device)
        padding_percentage = 100 * model.max_padding / FLAGS.window_size
        logging.info(f"max padding {model.max_padding}, percentage {padding_percentage}%")
        logging.info(f"receptive field {model.receptive_field}")
    elif FLAGS.model == "wavenet_unmasked":
        model = WaveNetEmotionIDModel(feat_dim, num_emotions, masked=False).to(device)
        padding_percentage = 100 * model.max_padding / FLAGS.window_size
        logging.info(f"max padding {model.max_padding}, percentage {padding_percentage}%")
        logging.info(f"receptive field {model.receptive_field}")
    else:
        raise NameError("Model name not found")
    logging.info(f"number of classes {num_emotions}")
    logging.info(f"model param count {sum(x.numel() for x in model.parameters()):,}")

    optimizer = RAdam(model.parameters(), eps=1e-05, lr=FLAGS.lr)
    if FLAGS.lr_schedule:
        scheduler = FlatCA(optimizer, steps=FLAGS.steps, eta_min=0)
    else:
        scheduler = EmptyScheduler(optimizer)
    step = 0
    optimizer.best_val_loss = inf

    if FLAGS.checkpoint:
        # loading state_dicts in-place
        load_checkpoint(FLAGS.checkpoint, model, optimizer, scheduler=scheduler)
        step = optimizer.restored_step
    dump_checkpoint_on_kill(model, optimizer, scheduler, FLAGS.checkpoint_out)
    set_seeds(FLAGS.seed + step)

    model.train()
    for batch in train_datastream:
        data, labels = batch["data"].to(device), batch["labels"]
        features = cpc(data)
        pred = model(features)
        labels = resample_1d(labels, pred.shape[1]).reshape(-1).to(device)

        # get cross entropy loss against emotion labels and take step
        optimizer.zero_grad()
        # reuse the forward pass above instead of running the model a second time
        output = pred.reshape(-1, num_emotions)
        loss = F.cross_entropy(output, labels)
        loss.backward()
        clip_grad_norm_(model.parameters(), FLAGS.clip_thresh)
        optimizer.step()
        scheduler.step()

        # log training losses
        logging.info(f"{step} train steps, loss={loss.item():.5}")
        tb_logger.add_scalar("train/loss", loss, step)
        train_losses_fh.write(f"{step}, {loss.item()}\n")
        tb_logger.add_scalar("train/lr", scheduler.get_lr()[0], step)

        # validate periodically
        if step % FLAGS.val_every == 0 and step != 0:
            valid_loss = validate(valid_datastream, cpc, model, num_emotions)
            # log validation losses
            logging.info(
                f"{step} validation, loss={valid_loss.item():.5}, "
                f"{valid_frames:,} items validated"
            )
            tb_logger.add_scalar("valid/loss", valid_loss, step)
            valid_losses_fh.write(f"{step}, {valid_loss}\n")
            val_results = validate_filewise(parsed_valid_dbl, cpc, model, num_emotions)
            test_results = validate_filewise(parsed_test_dbl, cpc, model, num_emotions)
            for results, dataset in zip([val_results, test_results], ["valid", "test"]):
                tb_logger.add_scalar(f"{dataset}/full_loss", results["average_loss"], step)
                for name in ["framewise", "filewise"]:
                    cm = fig2tensor(results[name]["confusion_matrix"])
                    # include the granularity in the tag so framewise and filewise
                    # metrics don't overwrite each other at the same step
                    tb_logger.add_scalar(f"{dataset}/{name}/accuracy", results[name]["accuracy"], step)
                    tb_logger.add_scalar(f"{dataset}/{name}/f1_score", results[name]["average_f1"], step)
                    tb_logger.add_image(f"{dataset}/{name}/confusion_matrix", cm, step)
            for emotion, f1 in val_results["framewise"]["class_f1"].items():
                tb_logger.add_scalar(f"f1/{emotion}", f1, step)

            if valid_loss.item() < optimizer.best_val_loss:
                logging.info("Saving new best validation")
                save(model, FLAGS.model_out + ".bestval")
                optimizer.best_val_loss = valid_loss.item()

        # save out model periodically
        if step % FLAGS.save_every == 0 and step != 0:
            save(model, FLAGS.model_out + ".step" + str(step))

        if step >= FLAGS.steps:
            break
        step += 1

    save(model, FLAGS.model_out)
    # close loss logging file handles
    train_losses_fh.close()
    valid_losses_fh.close()
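# NOTE: `resample_1d` is used above to align per-frame emotion labels with the
# model output length but is defined elsewhere in the repo. A plausible minimal
# sketch, assuming nearest-neighbour resampling along the final (time)
# dimension; the real implementation may differ.
def resample_1d(x, target_len):
    # Pick, for each of `target_len` output positions, the nearest source index.
    idx = torch.linspace(0, x.shape[-1] - 1, steps=target_len).round().long()
    return x[..., idx]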
def validate_filewise(dbl, cpc, model, num_emotions):
    logging.info("Starting filewise validation")
    losses = []
    frame_preds = []
    frame_refs = []
    file_preds = []
    file_refs = []
    model.eval()
    # Stash and later restore states for non-leaky validation
    cpc.stash_state()
    model.stash_state()

    # loop over each file in the dbl
    for dbl_entry in dbl:
        # file specific stream to iterate over
        stream = EmotionIDSingleFileStream(
            dbl_entry, FLAGS.window_size, FLAGS.emotion_set_path, audiostream_class=cpc.data_class
        )
        single_file_preds = []
        for batch in stream:
            with torch.no_grad():
                data = torch.tensor(batch["data"]).unsqueeze(0).to(device)
                labels = torch.tensor(batch["labels"]).unsqueeze(0)
                # get predictions
                features = cpc(data)
                logits = model(features)
                # get pred
                pred = logits.argmax(dim=2).squeeze(dim=0)
                frame_preds.append(pred)
                single_file_preds.append(pred)
                if logits.shape[1] > 1:
                    labels = resample_1d(labels, logits.shape[1])
                labels = labels.reshape(-1)
                frame_refs.append(labels)
                # get loss
                logits = logits.reshape(-1, num_emotions)
                losses.append(F.cross_entropy(logits, labels.to(device)).item())
        # majority vote over the file's frame predictions
        counts = np.bincount(torch.cat(single_file_preds, dim=0).cpu().numpy())
        file_preds.append(np.argmax(counts))
        file_refs.append(labels[-1].item())

    frame_preds = torch.cat(frame_preds, dim=0).cpu().numpy()
    frame_refs = torch.cat(frame_refs, dim=0).cpu().numpy()
    file_preds = np.array(file_preds)
    file_refs = np.array(file_refs)

    results = {}
    results["average_loss"] = np.array(losses).mean()
    emotion2id = get_emotion_to_id_mapping(FLAGS.emotion_set_path)
    for refs, preds, name in zip(
        [frame_refs, file_refs], [frame_preds, file_preds], ["framewise", "filewise"]
    ):
        results[name] = {}
        results[name]["accuracy"] = accuracy_score(refs, preds)
        results[name]["average_f1"] = f1_score(refs, preds, average="macro")
        results[name]["class_f1"] = {}
        f1_scores = f1_score(refs, preds, average=None)
        for f1, emotion in zip(f1_scores, emotion2id.keys()):
            results[name]["class_f1"][emotion] = f1
        cm = confusion_matrix(refs, preds)
        fig, ax = plt.subplots(figsize=(8, 6), dpi=150)
        sns.heatmap(cm, annot=True, ax=ax)
        ax.set_xlabel("Predicted labels")
        ax.set_ylabel("True labels")
        results[name]["confusion_matrix"] = fig

    cpc.pop_state()
    model.pop_state()
    model.train()
    return results
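# NOTE: `EmptyScheduler` (used in train() when --lr_schedule is off) is defined
# elsewhere in the repo. A minimal no-op stand-in matching the interface used
# above (step() and get_lr()) is sketched here; the state_dict methods are
# guesses to keep the checkpoint helpers happy, and the real class may differ.
class EmptyScheduler:
    def __init__(self, optimizer):
        self.optimizer = optimizer

    def step(self):
        pass  # no-op: the learning rate stays at its initial value

    def get_lr(self):
        # Report the current learning rate of every parameter group.
        return [group["lr"] for group in self.optimizer.param_groups]

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        pass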