import torch
import torch.nn as nn
from sklearn.metrics import balanced_accuracy_score

# Assumed to be configured elsewhere in the project; a sensible default is given here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train_video(train_data, video_train_data, eval_data, video_eval_data,
                test_data, video_test_data, model, batch_size, num_epochs,
                model_name):
    """Train the video-only model; pairs with evaluate_video below.

    Renamed from `train` to avoid colliding with the fusion variant.
    """
    model = model.train()
    # Inverse-frequency class weights for a 680-sample binary dataset
    # (261 vs. 419 examples per class).
    weights = [680 / 261, 680 / 419]
    class_weights = torch.tensor(weights, device=device)
    criterion = nn.CrossEntropyLoss(weight=class_weights).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=2, factor=0.5, cooldown=0,
        min_lr=1e-6, verbose=True)

    best_valid_loss = float('inf')
    eval_every = 10  # evaluate and plot every 10 optimizer steps
    dev_loss_list, train_loss_list = [], []
    train_acc_list, dev_acc_list = [], []
    step_num = 0

    for epoch_id in range(num_epochs):
        for batch_features, video_batch_features, batch_labels in read_batches(
                train_data, video_train_data, batch_size, True, device,
                is_multi=False):
            model = model.train()
            optimizer.zero_grad()
            step_num += 1
            # (batch, time, feat) -> (batch, feat, time); the physiological
            # features are permuted here, but this model consumes video only.
            batch_features = batch_features.permute((0, 2, 1))
            out = model(video_batch_features)
            outs = [out]
            losses = []
            for i, out in enumerate(outs):
                losses.append(criterion(
                    out,                                # (batch_size, num_classes)
                    batch_labels[:, i:i + 1].view(-1)   # (batch_size,)
                ))
            loss = sum(losses)
            loss.backward()
            optimizer.step()

            if step_num % eval_every == 0:
                train_loss, train_bal, train_f, train_acc = evaluate_video(
                    model, train_data, video_train_data, batch_size, criterion)
                dev_loss, dev_bal, dev_f, dev_acc = evaluate_video(
                    model, eval_data, video_eval_data, batch_size, criterion)
                train_loss_list.append(train_loss)
                dev_loss_list.append(dev_loss)
                train_acc_list.append(train_bal)
                dev_acc_list.append(dev_bal)
                plot_losses(model_name, train_loss_list, dev_loss_list,
                            train_acc_list, dev_acc_list, step_num)

        # End-of-epoch evaluation, LR scheduling, and checkpointing.
        train_loss, train_bal, train_f, train_acc = evaluate_video(
            model, train_data, video_train_data, batch_size, criterion)
        dev_loss, dev_bal, dev_f, dev_acc = evaluate_video(
            model, eval_data, video_eval_data, batch_size, criterion)
        scheduler.step(dev_loss)
        print()
        print(f'TRAIN Epoch {epoch_id} | balanced train. accuracy={train_bal} | '
              f'train fscore={train_f} | train_loss={ff(train_loss, n=8)} | '
              f'train. accuracy={train_acc} |')
        print(f'Epoch {epoch_id} | balanced dev. accuracy={dev_bal} | '
              f'dev fscore={dev_f} | dev_loss={ff(dev_loss, n=8)} | '
              f'dev. accuracy={dev_acc} |')
        # Checkpoint on the best validation loss seen so far.
        if dev_loss < best_valid_loss:
            best_valid_loss = dev_loss
            torch.save(model.state_dict(), model_name + '.pt')

    # Restore the best checkpoint and report test metrics.
    model.load_state_dict(torch.load(model_name + '.pt'))
    test_loss, test_bal, test_f, test_acc = evaluate_video(
        model, test_data, video_test_data, batch_size, criterion)
    print(f'TEST - balanced test. accuracy={test_bal} | test fscore={test_f} | '
          f'test_loss={ff(test_loss, n=8)} | test acc={test_acc} |')
    plot_losses(model_name, train_loss_list, dev_loss_list, train_acc_list,
                dev_acc_list, step_num, test_acc=test_bal, test_loss=test_loss)
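# `plot_losses` is called above but not defined in this section. The sketch
# below is a plausible stand-in, assuming the lists hold plain floats (the
# per-head dicts returned by evaluate_video would need unpacking first) and
# that a PNG per model is an acceptable output; the real project helper may
# differ.
import matplotlib.pyplot as plt


def plot_losses(model_name, train_losses, dev_losses, train_accs, dev_accs,
                step_num, test_acc=None, test_loss=None):
    """Save loss and accuracy curves for the run so far (sketch)."""
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
    ax_loss.plot(train_losses, label='train loss')
    ax_loss.plot(dev_losses, label='dev loss')
    ax_loss.set_xlabel('evaluation step')
    ax_loss.legend()
    ax_acc.plot(train_accs, label='train balanced acc.')
    ax_acc.plot(dev_accs, label='dev balanced acc.')
    ax_acc.set_xlabel('evaluation step')
    ax_acc.legend()
    title = f'{model_name} @ step {step_num}'
    if test_acc is not None:
        title += f' | test bal. acc={test_acc}, test loss={test_loss}'
    fig.suptitle(title)
    fig.savefig(f'{model_name}_curves.png')
    plt.close(fig)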
def train_fusion(train_data_phys, video_train_data, eval_data, video_eval_data,
                 test_data, video_test_data, model, batch_size, num_epochs,
                 model_name):
    """Train the video + physiology fusion model; pairs with evaluate_fusion.

    Renamed from `train` to avoid colliding with the video-only variant.
    """
    model = model.train()
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=0, factor=0.33, cooldown=0,
        min_lr=1e-6, verbose=True)
    best_valid_loss = float('inf')

    for epoch_id in range(num_epochs):
        model = model.train()
        for batch_features, video_batch_features, batch_labels in read_batches(
                train_data_phys, video_train_data, batch_size, True, device):
            optimizer.zero_grad()
            # (batch, time, feat) -> (batch, feat, time) for the phys branch.
            batch_features = batch_features.permute((0, 2, 1))
            out = model(video_batch_features, batch_features)
            loss = criterion(
                out,                    # (batch_size, num_classes)
                batch_labels.view(-1)   # (batch_size,)
            )
            loss.backward()
            optimizer.step()

        train_loss, train_acc, train_bal, train_f = evaluate_fusion(
            model, train_data_phys, video_train_data, batch_size, criterion)
        dev_loss, dev_acc, dev_bal, dev_f = evaluate_fusion(
            model, eval_data, video_eval_data, batch_size, criterion)
        scheduler.step(dev_loss)
        print(f'TRAIN Epoch {epoch_id} | balanced train. accuracy={train_bal} | '
              f'train fscore={train_f} | train_loss={ff(train_loss, n=8)} | '
              f'train. accuracy={train_acc} |')
        print(f'Epoch {epoch_id} | balanced dev. accuracy={dev_bal} | '
              f'dev fscore={dev_f} | dev_loss={ff(dev_loss, n=8)} | '
              f'dev. accuracy={dev_acc} |')
        # Checkpoint on the best validation loss seen so far.
        if dev_loss < best_valid_loss:
            best_valid_loss = dev_loss
            torch.save(model.state_dict(), model_name + '.pt')

    # Restore the best checkpoint and report test metrics.
    model.load_state_dict(torch.load(model_name + '.pt'))
    test_loss, test_acc, test_bal, test_f = evaluate_fusion(
        model, test_data, video_test_data, batch_size, criterion)
    print(f'TEST - balanced test. accuracy={test_bal} | test fscore={test_f} | '
          f'test_loss={ff(test_loss, n=8)} | test acc={test_acc} |')
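# `read_batches` is assumed to be defined elsewhere in the project; a minimal
# sketch is given below for self-containedness. It assumes `dataset` is a
# sequence of (features, label) pairs aligned index-for-index with
# `video_dataset`, and that `is_multi` only affects label shape (ignored
# here). Labels are shaped (batch, 1) so both the `batch_labels[:, i:i+1]`
# and `batch_labels.view(-1)` access patterns above work.
import random


def read_batches(dataset, video_dataset, batch_size, shuffle, device,
                 is_multi=False):
    """Yield aligned (phys, video, label) mini-batches on `device` (sketch)."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.shuffle(indices)
    for start in range(0, len(indices), batch_size):
        idx = indices[start:start + batch_size]
        feats = torch.stack([torch.as_tensor(dataset[i][0], dtype=torch.float32)
                             for i in idx])
        vids = torch.stack([torch.as_tensor(video_dataset[i], dtype=torch.float32)
                            for i in idx])
        labels = torch.tensor([dataset[i][1] for i in idx],
                              dtype=torch.long).view(-1, 1)
        yield feats.to(device), vids.to(device), labels.to(device)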
def evaluate_video(model, eval_dataset, eval_dataset_video, batch_size,
                   criterion):
    """Evaluate the video-only model; returns per-head metric dicts.

    Renamed from `evaluate` to avoid colliding with the fusion variant.
    """
    model = model.eval()
    epoch_loss = 0
    nc = 1  # number of prediction heads
    epoch_acc_dict = {i: 0 for i in range(nc)}
    all_tp = {i: 0 for i in range(nc)}
    all_fp = {i: 0 for i in range(nc)}
    all_tn = {i: 0 for i in range(nc)}
    all_fn = {i: 0 for i in range(nc)}
    num_samples = 0
    all_labels = {i: [] for i in range(nc)}
    all_preds = {i: [] for i in range(nc)}

    with torch.no_grad():  # no gradients needed during evaluation
        for batch_features, video_batch_features, batch_labels in read_batches(
                eval_dataset, eval_dataset_video, batch_size, False, device,
                is_multi=False):
            batch_features = batch_features.permute((0, 2, 1))
            out, pred_classes = model.predict(video_batch_features)
            outs = [out]
            pred_classes = [pred_classes]
            losses = []
            for i, (out, pred_class) in enumerate(zip(outs, pred_classes)):
                losses.append(criterion(
                    out,                                # (batch_size, num_classes)
                    batch_labels[:, i:i + 1].view(-1)   # (batch_size,)
                ))
                # Accumulates correct-prediction counts; normalized by the
                # total sample count below.
                epoch_acc_dict[i] += categorical_accuracy(
                    pred_class, batch_labels[:, i:i + 1])
                tn, fp, fn, tp = conf_matrix(pred_class, batch_labels[:, i:i + 1])
                all_tp[i] += tp
                all_tn[i] += tn
                all_fp[i] += fp
                all_fn[i] += fn
                all_labels[i] += list(
                    batch_labels[:, i:i + 1].view(-1).cpu().numpy())
                all_preds[i] += list(pred_class)
            loss = sum(losses)
            epoch_loss += loss.item()
            num_samples += len(batch_features)

    bal_acc_dict = {}
    fscore_dict = {}
    for i in range(nc):
        bal_acc = balanced_accuracy_score(all_labels[i], all_preds[i])
        # Guard both ratios against empty denominators.
        denom_r = all_tp[i] + all_fn[i]
        recall = all_tp[i] / denom_r if denom_r > 0 else 0.0
        denom_p = all_tp[i] + all_fp[i]
        if denom_p > 0:
            precision = all_tp[i] / denom_p
            fscore = ((2 * precision * recall) / (precision + recall)
                      if precision + recall > 0 else 0.0)
        else:
            precision = -1  # sentinel: no positive predictions made
            fscore = -1
        bal_acc_dict[i] = ff(bal_acc)
        fscore_dict[i] = ff(fscore)
        epoch_acc_dict[i] = ff(epoch_acc_dict[i] / num_samples)
    return epoch_loss / num_samples, bal_acc_dict, fscore_dict, epoch_acc_dict
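# `categorical_accuracy` and `conf_matrix` are also project helpers not shown
# in this section. The sketches below match how they are used above:
# `categorical_accuracy` is assumed to return the *count* of correct
# predictions (so dividing the running total by the sample count yields
# accuracy), and `conf_matrix` to return binary counts in (tn, fp, fn, tp)
# order.
from sklearn.metrics import confusion_matrix


def categorical_accuracy(pred_classes, labels):
    """Number of predictions matching the gold labels (sketch)."""
    preds = torch.as_tensor(pred_classes).view(-1).cpu()
    gold = torch.as_tensor(labels).view(-1).cpu()
    return (preds == gold).sum().item()


def conf_matrix(pred_classes, labels):
    """Binary confusion counts as (tn, fp, fn, tp) (sketch)."""
    preds = torch.as_tensor(pred_classes).view(-1).cpu().numpy()
    gold = torch.as_tensor(labels).view(-1).cpu().numpy()
    tn, fp, fn, tp = confusion_matrix(gold, preds, labels=[0, 1]).ravel()
    return tn, fp, fn, tp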
def evaluate_fusion(model, eval_dataset, eval_dataset_video, batch_size,
                    criterion):
    """Evaluate the fusion model; returns (loss, acc, balanced acc, F-score).

    Renamed from `evaluate` to avoid colliding with the video-only variant.
    """
    model = model.eval()
    epoch_loss = 0
    correct = 0  # running count of correct predictions
    all_tp, all_fp, all_tn, all_fn = 0, 0, 0, 0
    num_samples = 0
    all_labels = []
    all_preds = []

    with torch.no_grad():  # no gradients needed during evaluation
        for batch_features, video_batch_features, batch_labels in read_batches(
                eval_dataset, eval_dataset_video, batch_size, False, device):
            batch_features = batch_features.permute((0, 2, 1))
            out, pred_classes = model.predict(video_batch_features, batch_features)
            loss = criterion(
                out,                    # (batch_size, num_classes)
                batch_labels.view(-1)   # (batch_size,)
            )
            epoch_loss += loss.item()
            correct += categorical_accuracy(pred_classes, batch_labels)
            tn, fp, fn, tp = conf_matrix(pred_classes, batch_labels)
            all_tp += tp
            all_tn += tn
            all_fp += fp
            all_fn += fn
            num_samples += len(batch_features)
            all_labels += list(batch_labels.view(-1).cpu().numpy())
            all_preds += list(pred_classes)

    acc = correct / num_samples
    # Balanced accuracy: mean of specificity and sensitivity.
    bal_acc = ((all_tn / (all_tn + all_fp)) + (all_tp / (all_tp + all_fn))) / 2
    recall = all_tp / (all_tp + all_fn) if all_tp + all_fn > 0 else 0.0
    if all_tp + all_fp > 0:
        precision = all_tp / (all_tp + all_fp)
        fscore = ((2 * precision * recall) / (precision + recall)
                  if precision + recall > 0 else 0.0)
    else:
        precision = -1  # sentinel: no positive predictions made
        fscore = -1
    return epoch_loss / num_samples, ff(acc), ff(bal_acc), ff(fscore)
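# `ff` is the formatting helper used in the log lines above; a plausible
# minimal version, assuming it just renders a float to `n` decimal places
# (the default of 4 is an assumption).
def ff(x, n=4):
    """Format a float to n decimal places as a string (sketch)."""
    return f'{float(x):.{n}f}'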