def train(params, m, datas):
    es = EarlyStopping(min_delta=params.min_delta, patience=params.patience)
    # optimizer: all parameters except the discriminator
    ps = [p[1] for p in m.named_parameters() if 'discriminator' not in p[0]]
    print('Model parameters: {}'.format(sum(p.numel() for p in ps)))
    optimizer = optim.Adam(ps, lr=params.init_lr)
    if params.adv_training:
        dis_ps = [p[1] for p in m.named_parameters() if 'discriminator' in p[0]]
        dis_optimizer = optim.Adam(dis_ps, lr=params.init_lr)
        dis_enc_ps = [p[1] for p in m.named_parameters()
                      if 'encoder' in p[0] or 'embedding' in p[0]]
        dis_enc_optimizer = optim.Adam(dis_enc_ps, lr=params.init_lr)
    # all training instances are split between the two languages; the data are currently balanced
    n_batch = (len(datas) * datas[0].train_size // params.bs
               if len(datas) * datas[0].train_size % params.bs == 0
               else len(datas) * datas[0].train_size // params.bs + 1)
    data_idxs = {}
    for i, data in enumerate(datas):
        lang = data.vocab.lang
        data_idxs[lang] = list(range(data.train_size))
    # number of iterations
    cur_it = 0
    # write to tensorboard
    writer = SummaryWriter('./history/{}'.format(params.log_path)) if params.write_tfboard else None
    nll_dev = math.inf
    best_nll_dev = math.inf
    kld_dev = math.inf
    for i in range(params.ep):
        for lang in data_idxs:
            shuffle(data_idxs[lang])
        for j in range(n_batch):
            if params.task == 'xl' or params.task == 'xl-adv':
                # alternate languages between batches
                lang_idx = j % len(datas)
                data = datas[lang_idx]
                lang = data.vocab.lang
                train_idxs = data_idxs[lang][j // len(datas) * params.bs:
                                             (j // len(datas) + 1) * params.bs]
            elif params.task == 'mo':
                lang = params.langs[0]
                lang_idx = params.lang_dict[lang]
                data = datas[lang_idx]
                train_idxs = data_idxs[lang][j * params.bs:(j + 1) * params.bs]
            padded_batch, batch_lens = get_batch(train_idxs, data, data.train_idxs,
                                                 data.train_lens, params.cuda)
            optimizer.zero_grad()
            if params.adv_training:
                dis_optimizer.zero_grad()
                dis_enc_optimizer.zero_grad()
            m.train()
            nll_batch, kld_batch, ls_dis, ls_enc = m(lang, padded_batch, batch_lens)
            cur_it += 1
            loss_batch, alpha = calc_loss_batch(params, nll_batch, kld_batch, cur_it, n_batch)
            '''
            # add adversarial loss to the encoder
            if cur_it > params.adv_ep * n_batch:
                loss_batch += ls_enc
            '''
            if not params.adv_training:
                loss_batch.backward()
                optimizer.step()
            else:
                ls_dis = ls_dis.mean()
                ls_enc = ls_enc.mean()
                loss_batch = loss_batch + ls_dis + ls_enc
                loss_batch.backward()
                optimizer.step()
                dis_optimizer.step()
                dis_enc_optimizer.step()
            out_xling(i, j, n_batch, loss_batch, nll_batch, kld_batch, best_nll_dev,
                      nll_dev, kld_dev, es.num_bad_epochs, ls_dis=ls_dis, ls_enc=ls_enc)
            update_tensorboard(writer, loss_batch, nll_batch, kld_batch, alpha,
                               nll_dev, kld_dev, cur_it, ls_dis=ls_dis, ls_enc=ls_enc)
            if cur_it % params.VAL_EVERY == 0:
                sys.stdout.write('\n')
                sys.stdout.flush()
                # validation
                nll_dev, kld_dev = test(params, m, datas)
                if es.step(nll_dev):
                    print('\nEarly stopped.')
                    return
                elif es.is_better(nll_dev, best_nll_dev):
                    best_nll_dev = nll_dev
                    # save model
                    for lang in params.langs:
                        lang_idx = params.lang_dict[lang]
                        m.save_embedding(params, datas[lang_idx])
                    m.save_model(params, datas)
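# NOTE: EarlyStopping is used by all of the training loops in this file but its
# implementation is not included in this excerpt. Below is a minimal sketch written
# to be consistent with how it is called here (mode='min'/'max', min_delta, patience,
# step(), is_better(), is_best(), num_bad_epochs, best); the original helper may differ.
class EarlyStopping:
    def __init__(self, mode='min', min_delta=0.0, patience=10):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0

    def is_better(self, metric, best):
        # 'min' for losses (e.g. dev NLL), 'max' for accuracies (e.g. CLDC dev accuracy)
        if self.mode == 'min':
            return metric < best - self.min_delta
        return metric > best + self.min_delta

    def is_best(self):
        # True if the most recent step() improved on the best value seen so far
        return self.num_bad_epochs == 0

    def step(self, metric):
        # returns True when the metric has not improved for `patience` validations
        if self.best is None or self.is_better(metric, self.best):
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs > self.patience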
for i, (src_test, trg_test) in tqdm(enumerate(test_loader),
                                    total=int(len(test_set) / batch_size)):
    test_logit = model(Variable(src_test).to(device), Variable(trg_test).to(device))
    # shift the gold target left by one token and pad with 0, so the logit at
    # position t is scored against the next token
    trg_test = torch.cat(
        (torch.index_select(trg_test, 1, torch.LongTensor(list(range(1, pad_len)))),
         torch.LongTensor(np.zeros([trg_test.shape[0], 1]))),
        dim=1)
    test_loss = loss_criterion(test_logit.contiguous().view(-1, vocab_size),
                               Variable(trg_test).view(-1).to(device))
    test_loss_sum += test_loss.item()
    del test_loss, test_logit
print("Evaluation Loss", test_loss_sum)
# es.new_loss(test_loss_sum)
if es.step(test_loss_sum):
    print('Overfitting detected, stopping early.')
    break
# Save Model
torch.save(model.state_dict(),
           open(os.path.join('checkpoint',
                             'new_simple_bar' + '_epoch_%d' % (epoch) + '.model'), 'wb'))
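# NOTE: the index_select/cat construction above implements the usual teacher-forcing
# target shift. A hypothetical standalone equivalent is sketched below for clarity
# (the pad id 0 and the function name are assumptions, not part of the original code).
def shift_targets(trg, pad_id=0):
    """Drop the first token of every target sequence and append a pad token, so the
    prediction at position t is scored against the token at position t + 1."""
    pad_col = torch.full((trg.size(0), 1), pad_id, dtype=trg.dtype, device=trg.device)
    return torch.cat([trg[:, 1:], pad_col], dim=1)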
def one_fold(num_fold, train_index, dev_index): print("Training on fold:", num_fold) X_train, X_dev = [X[i] for i in train_index], [X[i] for i in dev_index] y_train, y_dev = y[train_index], y[dev_index] # construct data loader # for one fold, test data comes from k fold split. train_data_set = create_data.TrainDataSet(X_train, y_train, EMAI_PAD_LEN, SENT_PAD_LEN, word2id, emoji_st, use_unk=True) dev_data_set = create_data.TrainDataSet(X_dev, y_dev, EMAI_PAD_LEN, SENT_PAD_LEN, word2id, emoji_st, use_unk=True) dev_data_loader = DataLoader(dev_data_set, batch_size=BATCH_SIZE, shuffle=False) # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") final_pred_best = None # This is to prevent model diverge, once happen, retrain while True: is_diverged = False # Model is defined in HierarchicalPredictor if CONTINUE: model = torch.load(opt.out_path) else: model = HierarchicalAttPredictor(SENT_EMB_DIM, SENT_HIDDEN_SIZE, CTX_LSTM_DIM, num_of_vocab, SENT_PAD_LEN, id2word, USE_ELMO=True, ADD_LINEAR=False) model.load_embedding(emb) model.deepmoji_model.load_specific_weights( PRETRAINED_PATH, exclude_names=['output_layer']) model.cuda() optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=True) scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=GAMMA) # loss_criterion_binary = nn.CrossEntropyLoss(weight=weight_list_binary) # if loss == 'focal': loss_criterion = FocalLoss(gamma=opt.focal) elif loss == 'ce': loss_criterion = nn.BCELoss() es = EarlyStopping(patience=EARLY_STOP_PATIENCE) final_pred_list_test = None result_print = {} for num_epoch in range(MAX_EPOCH): # to ensure shuffle at ever epoch train_data_loader = DataLoader(train_data_set, batch_size=BATCH_SIZE, shuffle=True) print('Begin training epoch:', num_epoch, end='...\t') sys.stdout.flush() # stepping scheduler scheduler.step(num_epoch) print('Current learning rate', scheduler.get_lr()) ## Training step train_loss = 0 model.train() for i, (a, a_len, emoji_a, e_c) \ in tqdm(enumerate(train_data_loader), total=len(train_data_set)/BATCH_SIZE): optimizer.zero_grad() e_c = e_c.type(torch.float) pred = model(a.cuda(), a_len, emoji_a.cuda()) loss_label = loss_criterion(pred.squeeze(1), e_c.view(-1).cuda()).cuda() # training trilogy loss_label.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP) optimizer.step() train_loss += loss_label.data.cpu().numpy() * a.shape[0] del pred, loss_label ## Evaluatation step model.eval() dev_loss = 0 # pred_list = [] for i, (a, a_len, emoji_a, e_c) in enumerate(dev_data_loader): with torch.no_grad(): e_c = e_c.type(torch.float) pred = model(a.cuda(), a_len, emoji_a.cuda()) loss_label = loss_criterion( pred.squeeze(1), e_c.view(-1).cuda()).cuda() dev_loss += loss_label.data.cpu().numpy() * a.shape[0] # pred_list.append(pred.data.cpu().numpy()) # gold_list.append(e_c.numpy()) del pred, loss_label print('Training loss:', train_loss / len(train_data_set), end='\t') print('Dev loss:', dev_loss / len(dev_data_set)) # print(classification_report(gold_list, pred_list, target_names=EMOS)) # get_metrics(pred_list, gold_list) # Gold Test testing print('Final test testing...') final_pred_list_test = [] model.eval() for i, (a, a_len, emoji_a) in enumerate(final_test_data_loader): with torch.no_grad(): pred = model(a.cuda(), a_len, emoji_a.cuda()) final_pred_list_test.append(pred.data.cpu().numpy()) del a, pred print("final_pred_list_test", len(final_pred_list_test)) final_pred_list_test = np.concatenate(final_pred_list_test, axis=0) final_pred_list_test = 
np.squeeze(final_pred_list_test, axis=1) print("final_pred_list_test_concat", len(final_pred_list_test)) accuracy, precision, recall, f1 = get_metrics( np.asarray(final_test_target_list), np.asarray(final_pred_list_test)) result_print.update( {num_epoch: [accuracy, precision, recall, f1]}) if dev_loss / len(dev_data_set) > 1.3 and num_epoch > 4: print("Model diverged, retry") is_diverged = True break if es.step(dev_loss): # overfitting print('overfitting, loading best model ...') break else: if es.is_best(): print('saving best model ...') if final_pred_best is not None: del final_pred_best final_pred_best = deepcopy(final_pred_list_test) else: print('not best model, ignoring ...') if final_pred_best is None: final_pred_best = deepcopy(final_pred_list_test) with open(result_path, 'wb') as w: pkl.dump(result_print, w) if is_diverged: print("Reinitialize model ...") del model continue real_test_results.append(np.asarray(final_pred_best)) # saving model for inference torch.save(model, opt.out_path) del model break
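# NOTE: get_metrics() is imported from elsewhere in the original project. For the
# binary, sigmoid-output setting used above, a plausible sketch is the following
# (the 0.5 threshold and the use of scikit-learn are assumptions):
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def get_metrics(targets, pred_probs, threshold=0.5):
    preds = (np.asarray(pred_probs) >= threshold).astype(int)
    targets = np.asarray(targets).astype(int)
    accuracy = accuracy_score(targets, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        targets, preds, average='binary', zero_division=0)
    return accuracy, precision, recall, f1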
def train(params, m, datas): # early stopping es = EarlyStopping(mode='max', patience=params.cldc_patience) # set optimizer optimizer = get_optimizer(params, m) # training on one lang, and dev/test for another lang # get training train_lang, train_data = get_lang_data(params, datas, training=True) # get dev and test, dev is the same language as test test_lang, test_data = get_lang_data(params, datas) n_batch = train_data.train_size // params.cldc_bs if train_data.train_size % params.cldc_bs == 0 else train_data.train_size // params.cldc_bs + 1 # per category data_idxs = [ list(range(len(train_idx))) for train_idx in train_data.train_idxs ] # number of iterations cur_it = 0 # write to tensorboard writer = SummaryWriter('./history/{}'.format( params.log_path)) if params.write_tfboard else None # best xx bdev = 0 btest = 0 # current xx cdev = 0 ctest = 0 dev_class_acc = {} test_class_acc = {} dev_cm = None test_cm = None # early stopping warm up flag, start es after some iters es_flag = False for i in range(params.cldc_ep): for data_idx in data_idxs: shuffle(data_idx) for j in range(n_batch): train_idxs = [] for k, data_idx in enumerate(data_idxs): if j < n_batch - 1: train_idxs.append( data_idx[int(j * params.cldc_bs * train_data.train_prop[k]):int( (j + 1) * params.cldc_bs * train_data.train_prop[k])]) elif j == n_batch - 1: train_idxs.append(data_idx[int(j * params.cldc_bs * train_data.train_prop[k]):]) batch_train, batch_train_lens, batch_train_lb = get_batch( params, train_idxs, train_data.train_idxs, train_data.train_lens) optimizer.zero_grad() m.train() cldc_loss_batch, _, batch_pred = m(train_lang, batch_train, batch_train_lens, batch_train_lb) batch_acc, batch_acc_cls = get_classification_report( params, batch_train_lb.data.cpu().numpy(), batch_pred.data.cpu().numpy()) if cldc_loss_batch < params.cldc_lossth: es_flag = True cldc_loss_batch.backward() out_cldc(i, j, n_batch, cldc_loss_batch, batch_acc, batch_acc_cls, bdev, btest, cdev, ctest, es.num_bad_epochs) optimizer.step() cur_it += 1 update_tensorboard(writer, cldc_loss_batch, batch_acc, cdev, ctest, dev_class_acc, test_class_acc, cur_it) if cur_it % params.CLDC_VAL_EVERY == 0: sys.stdout.write('\n') sys.stdout.flush() # validation #cdev, dev_class_acc, dev_cm = test(params, m, test_data.dev_idxs, test_data.dev_lens, test_data.dev_size, test_data.dev_prop, test_lang, cm = True) cdev, dev_class_acc, dev_cm = test(params, m, train_data.dev_idxs, train_data.dev_lens, train_data.dev_size, train_data.dev_prop, train_lang, cm=True) ctest, test_class_acc, test_cm = test(params, m, test_data.test_idxs, test_data.test_lens, test_data.test_size, test_data.test_prop, test_lang, cm=True) print(dev_cm) print(test_cm) if es.step(cdev): print('\nEarly Stoped.') return elif es.is_better(cdev, bdev): bdev = cdev btest = ctest #save_model(params, m) # reset bad epochs if not es_flag: es.num_bad_epochs = 0
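# NOTE: get_classification_report() is not defined in this excerpt. A minimal sketch
# consistent with its use above (overall accuracy plus a per-class accuracy mapping)
# is given below; `params` is accepted but unused here, which is an assumption.
def get_classification_report(params, gold, pred):
    gold = np.asarray(gold)
    pred = np.asarray(pred)
    acc = float((gold == pred).mean())
    acc_cls = {int(c): float((pred[gold == c] == c).mean()) for c in np.unique(gold)}
    return acc, acc_cls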
def main(params, m, data): # early stopping es = EarlyStopping(mode='max', patience=params.patience) # set optimizer optimizer = get_optimizer(params, m) n_batch = data.train_size // params.bs if data.train_size % params.bs == 0 else data.train_size // params.bs + 1 # per category data_idxs = [list(range(len(train_idx))) for train_idx in data.train_idxs] # number of iterations cur_it = 0 # best xx bdev = 0 btest = 0 # current xx cdev = 0 ctest = 0 dev_class_acc = {} test_class_acc = {} dev_cm = None test_cm = None # early stopping warm up flag, start es after some iters es_flag = False for i in range(params.ep): # self-training if params.self_train or i >= params.semi_warm_up: params.self_train = True first_update = (i == params.semi_warm_up) # only for zero-shot if first_update: es.num_bad_epochs = 0 es.best = 0 bdev = 0 btest = 0 data = self_train_merge_data(params, m, es, data, first=first_update) n_batch = data.self_train_size // params.bs if data.self_train_size % params.bs == 0 else data.self_train_size // params.bs + 1 # per category data_idxs = [ list(range(len(train_idx))) for train_idx in data.self_train_idxs ] for data_idx in data_idxs: shuffle(data_idx) for j in range(n_batch): train_idxs = [] for k, data_idx in enumerate(data_idxs): if params.self_train: train_prop = data.self_train_prop else: train_prop = data.train_prop if j < n_batch - 1: train_idxs.append( data_idx[int(j * params.bs * train_prop[k]):int((j + 1) * params.bs * train_prop[k])]) elif j == n_batch - 1: train_idxs.append(data_idx[int(j * params.bs * train_prop[k]):]) if params.self_train: batch_train, _, batch_train_lb = get_batch( params, train_idxs, data.self_train_idxs, data.self_train_lens) else: batch_train, _, batch_train_lb = get_batch( params, train_idxs, data.train_idxs, data.train_lens) optimizer.zero_grad() m.train() loss_batch, logits = m(batch_train, labels=batch_train_lb) batch_pred = torch.argmax(logits, dim=1) batch_acc, batch_acc_cls = get_classification_report( params, batch_train_lb.data.cpu().numpy(), batch_pred.data.cpu().numpy()) if loss_batch < params.lossth: es_flag = True loss_batch.backward() out_cldc(i, j, n_batch, loss_batch, batch_acc, batch_acc_cls, bdev, btest, cdev, ctest, es.num_bad_epochs) optimizer.step() cur_it += 1 sys.stdout.write('\n') sys.stdout.flush() # validation cdev, dev_class_acc, dev_cm = test(params, m, data.dev_idxs, data.dev_lens, data.dev_size, data.dev_prop, cm=True) ctest, test_class_acc, test_cm = test(params, m, data.test_idxs, data.test_lens, data.test_size, data.test_prop, cm=True) print(dev_cm) print(test_cm) if es.step(cdev): print('\nEarly Stoped.') return elif es.is_better(cdev, bdev): bdev = cdev btest = ctest # reset bad epochs if not es_flag: es.num_bad_epochs = 0
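# NOTE: the n_batch expressions used above ("a // b if a % b == 0 else a // b + 1")
# are ceiling divisions. For clarity only, an equivalent helper is sketched below;
# the original expressions are left unchanged.
import math

def num_batches(num_examples, batch_size):
    # one extra batch whenever the last batch is partial
    return math.ceil(num_examples / batch_size)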
def one_fold(num_fold, train_index, dev_index): print("Training on fold:", num_fold) X_train, X_dev = [X[i] for i in train_index], [X[i] for i in dev_index] y_train, y_dev = y[train_index], y[dev_index] # construct data loader train_data_set = TrainDataSet(X_train, y_train, CONV_PAD_LEN, SENT_PAD_LEN, word2id, use_unk=True) dev_data_set = TrainDataSet(X_dev, y_dev, CONV_PAD_LEN, SENT_PAD_LEN, word2id, use_unk=True) dev_data_loader = DataLoader(dev_data_set, batch_size=BATCH_SIZE, shuffle=False) # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") pred_list_test_best = None final_pred_best = None # This is to prevent model diverge, once happen, retrain while True: is_diverged = False model = HierarchicalPredictor(SENT_EMB_DIM, SENT_HIDDEN_SIZE, num_of_vocab, USE_ELMO=True, ADD_LINEAR=False) model.load_embedding(emb) model.cuda() # model = nn.DataParallel(model) # model.to(device) optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=True) # # optimizer = optim.SGD(model.parameters(), lr=learning_rate) scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=opt.gamma) if opt.w == 1: weight_list = [0.3, 0.3, 0.3, 1.7] weight_list_binary = [0.3, 1.7] elif opt.w == 2: weight_list = [ 0.3198680179, 0.246494733, 0.2484349259, 1.74527696 ] weight_list_binary = [2 - weight_list[-1], weight_list[-1]] weight_list = [x**FLAT for x in weight_list] weight_label = torch.Tensor(weight_list).cuda() weight_list_binary = [x**FLAT for x in weight_list_binary] weight_binary = torch.Tensor(weight_list_binary).cuda() print('classification reweight: ', weight_list) print('binary loss reweight = weight_list_binary', weight_list_binary) # loss_criterion_binary = nn.CrossEntropyLoss(weight=weight_list_binary) # if opt.loss == 'focal': loss_criterion = FocalLoss(gamma=opt.focal, reduce=False) loss_criterion_binary = FocalLoss(gamma=opt.focal, reduce=False) # elif opt.loss == 'ce': loss_criterion = nn.CrossEntropyLoss(reduce=False) loss_criterion_binary = nn.CrossEntropyLoss(reduce=False) # loss_criterion_emo_only = nn.MSELoss() # es = EarlyStopping(min_delta=0.005, patience=EARLY_STOP_PATIENCE) es = EarlyStopping(patience=EARLY_STOP_PATIENCE) # best_model = None final_pred_list_test = None pred_list_test = None for num_epoch in range(MAX_EPOCH): # to ensure shuffle at ever epoch train_data_loader = DataLoader(train_data_set, batch_size=BATCH_SIZE, shuffle=True) print('Begin training epoch:', num_epoch, end='...\t') sys.stdout.flush() # stepping scheduler scheduler.step(num_epoch) print('Current learning rate', scheduler.get_lr()) train_loss = 0 model.train() for i, (a, a_len, emoji_a, e_c, e_c_binary, e_c_emo) \ in tqdm(enumerate(train_data_loader), total=len(train_data_set)/BATCH_SIZE): optimizer.zero_grad() elmo_a = elmo_encode(a) pred, pred2, pred3 = model(a.cuda(), a_len, emoji_a.cuda(), elmo_a) loss_label = loss_criterion(pred, e_c.view(-1).cuda()).cuda() loss_label = torch.matmul(torch.gather(weight_label, 0, e_c.view(-1).cuda()), loss_label) / \ e_c.view(-1).shape[0] loss_binary = loss_criterion_binary( pred2, e_c_binary.view(-1).cuda()).cuda() loss_binary = torch.matmul( torch.gather(weight_binary, 0, e_c_binary.view(-1).cuda()), loss_binary) / e_c.view(-1).shape[0] loss_emo = loss_criterion_emo_only(pred3, e_c_emo.cuda()) loss = (loss_label + LAMBDA1 * loss_binary + LAMBDA2 * loss_emo) / float(1 + LAMBDA1 + LAMBDA2) # loss = torch.matmul(torch.gather(weight, 0, trg.view(-1).cuda()), loss) / trg.view(-1).shape[0] # training trilogy loss.backward() 
torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP) optimizer.step() train_loss += loss.data.cpu().numpy() * a.shape[0] del pred, loss, elmo_a, e_c_emo, loss_binary, loss_label, loss_emo # Evaluate model.eval() dev_loss = 0 # pred_list = [] # gold_list = [] for i, (a, a_len, emoji_a, e_c, e_c_binary, e_c_emo) \ in enumerate(dev_data_loader): with torch.no_grad(): elmo_a = elmo_encode(a) pred, pred2, pred3 = model(a.cuda(), a_len, emoji_a.cuda(), elmo_a) loss_label = loss_criterion( pred, e_c.view(-1).cuda()).cuda() loss_label = torch.matmul( torch.gather(weight_label, 0, e_c.view(-1).cuda()), loss_label) / e_c.view(-1).shape[0] loss_binary = loss_criterion_binary( pred2, e_c_binary.view(-1).cuda()).cuda() loss_binary = torch.matmul( torch.gather(weight_binary, 0, e_c_binary.view(-1).cuda()), loss_binary) / e_c.view(-1).shape[0] loss_emo = loss_criterion_emo_only( pred3, e_c_emo.cuda()) loss = (loss_label + LAMBDA1 * loss_binary + LAMBDA2 * loss_emo) / float(1 + LAMBDA1 + LAMBDA2) dev_loss += loss.data.cpu().numpy() * a.shape[0] # pred_list.append(pred.data.cpu().numpy()) # gold_list.append(e_c.numpy()) del pred, loss, elmo_a, e_c_emo, loss_binary, loss_label, loss_emo print('Training loss:', train_loss / len(train_data_set), end='\t') print('Dev loss:', dev_loss / len(dev_data_set)) # print(classification_report(gold_list, pred_list, target_names=EMOS)) # get_metrics(pred_list, gold_list) if dev_loss / len(dev_data_set) > 1.3 and num_epoch > 4: print("Model diverged, retry") is_diverged = True break if es.step(dev_loss): # overfitting print('overfitting, loading best model ...') break else: if es.is_best(): print('saving best model ...') if final_pred_best is not None: del final_pred_best final_pred_best = deepcopy(final_pred_list_test) if pred_list_test_best is not None: del pred_list_test_best pred_list_test_best = deepcopy(pred_list_test) else: print('not best model, ignoring ...') if final_pred_best is None: final_pred_best = deepcopy(final_pred_list_test) if pred_list_test_best is None: pred_list_test_best = deepcopy(pred_list_test) # Gold Dev testing... print('Gold Dev testing....') pred_list_test = [] model.eval() for i, (a, a_len, emoji_a) in enumerate(gold_dev_data_loader): with torch.no_grad(): elmo_a = elmo_encode(a) # , __id2word=ex_id2word pred, _, _ = model(a.cuda(), a_len, emoji_a.cuda(), elmo_a) pred_list_test.append(pred.data.cpu().numpy()) del elmo_a, a, pred pred_list_test = np.argmax(np.concatenate(pred_list_test, axis=0), axis=1) # get_metrics(load_dev_labels('data/dev.txt'), pred_list_test) # Testing print('Gold test testing...') final_pred_list_test = [] model.eval() for i, (a, a_len, emoji_a) in enumerate(test_data_loader): with torch.no_grad(): elmo_a = elmo_encode(a) # , __id2word=ex_id2word pred, _, _ = model(a.cuda(), a_len, emoji_a.cuda(), elmo_a) final_pred_list_test.append(pred.data.cpu().numpy()) del elmo_a, a, pred final_pred_list_test = np.argmax(np.concatenate( final_pred_list_test, axis=0), axis=1) # get_metrics(load_dev_labels('data/test.txt'), final_pred_list_test) if is_diverged: print("Reinitialize model ...") del model continue all_fold_results.append(pred_list_test_best) real_test_results.append(final_pred_best) del model break
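# NOTE: the loss weighting above computes an un-reduced loss (reduce=False), gathers a
# class weight per example with torch.gather, and takes the dot product of the two
# vectors before dividing by the batch size. The hypothetical helper below restates
# that pattern in one place:
def reweighted_mean_loss(per_example_loss, targets, class_weights):
    # equivalent to torch.matmul(torch.gather(w, 0, targets), loss) / targets.shape[0]
    w = torch.gather(class_weights, 0, targets)
    return (w * per_example_loss).sum() / targets.shape[0]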
def one_fold(num_fold, train_index, dev_index): print("Training on fold:", num_fold) X_train, X_dev = [X[i] for i in train_index], [X[i] for i in dev_index] y_train, y_dev = y[train_index], y[dev_index] # construct data loader train_data_set = DataSet(X_train, y_train, SENT_PAD_LEN) train_data_loader = DataLoader(train_data_set, batch_size=BATCH_SIZE, shuffle=True) dev_data_set = DataSet(X_dev, y_dev, SENT_PAD_LEN) dev_data_loader = DataLoader(dev_data_set, batch_size=BATCH_SIZE, shuffle=False) gradient_accumulation_steps = 1 num_train_steps = int( len(train_data_set) / BATCH_SIZE / gradient_accumulation_steps * MAX_EPOCH) pred_list_test_best = None final_pred_best = None # This is to prevent model diverge, once happen, retrain while True: is_diverged = False model = BERT_classifer.from_pretrained(BERT_MODEL) model.add_output_layer(BERT_MODEL, NUM_EMO) model = nn.DataParallel(model) if HALF_PRECISION: # model = network_to_half(model) model.half() model.to(device) #model.cpu() # BERT optimizer param_optimizer = list(model.named_parameters()) no_decay = ['bias', 'gamma', 'beta'] optimizer_grouped_parameters = [ {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01}, {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0} ] optimizer = BertAdam(optimizer_grouped_parameters, lr=learning_rate, warmup=0.1, t_total=num_train_steps) if opt.w == 1: weight_list = [0.3, 0.3, 0.3, 1.7] weight_list_binary = [2 - weight_list[-1], weight_list[-1]] elif opt.w == 2: weight_list = [0.3198680179, 0.246494733, 0.2484349259, 1.74527696] weight_list_binary = [2 - weight_list[-1], weight_list[-1]] weight_list = [x**FLAT for x in weight_list] weight_label = torch.Tensor(weight_list).to(device) weight_list_binary = [x**FLAT for x in weight_list_binary] weight_binary = torch.Tensor(weight_list_binary).to(device) print('binary loss reweight = weight_list_binary', weight_list_binary) # loss_criterion_binary = nn.CrossEntropyLoss(weight=weight_list_binary) # if opt.loss == 'focal': loss_criterion = FocalLoss(gamma=opt.focal, reduce=False) loss_criterion_binary = FocalLoss(gamma=opt.focal, reduce=False) # elif opt.loss == 'ce': loss_criterion = nn.CrossEntropyLoss(reduce=False) loss_criterion_binary = nn.CrossEntropyLoss(reduce=False) # loss_criterion_emo_only = nn.MSELoss() # es = EarlyStopping(min_delta=0.005, patience=EARLY_STOP_PATIENCE) es = EarlyStopping(patience=EARLY_STOP_PATIENCE) final_pred_best = None final_pred_list_test = None pred_list_test = None for num_epoch in range(MAX_EPOCH): print('Begin training epoch:', num_epoch) sys.stdout.flush() train_loss = 0 model.train() for i, (tokens, masks, segments, e_c, e_c_binary, e_c_emo) in tqdm(enumerate(train_data_loader), total=len(train_data_set)/BATCH_SIZE): optimizer.zero_grad() if USE_TOKEN_TYPE: pred, pred2, pred3 = model(tokens.to(device), masks.to(device), segments.to(device)) else: pred, pred2, pred3 = model(tokens.to(device), masks.to(device)) loss_label = loss_criterion(pred, e_c.view(-1).to(device)).to(device) loss_label = torch.matmul(torch.gather(weight_label, 0, e_c.view(-1).to(device)), loss_label) / \ e_c.view(-1).shape[0] loss_binary = loss_criterion_binary(pred2, e_c_binary.view(-1).to(device)).to(device) loss_binary = torch.matmul(torch.gather(weight_binary, 0, e_c_binary.view(-1).to(device)), loss_binary) / e_c.view(-1).shape[0] loss_emo = loss_criterion_emo_only(pred3, e_c_emo.to(device)) loss = (loss_label + LAMBDA1 * loss_binary + 
LAMBDA2 * loss_emo) / float(1 + LAMBDA1 + LAMBDA2) # training trilogy loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP) optimizer.step() train_loss += loss.data.cpu().numpy() * tokens.shape[0] del loss, pred # Evaluate model.eval() dev_loss = 0 # pred_list = [] # gold_list = [] for i, (tokens, masks, segments, e_c, e_c_binary, e_c_emo) in enumerate(dev_data_loader): with torch.no_grad(): if USE_TOKEN_TYPE: pred, pred2, pred3 = model(tokens.to(device), masks.to(device), segments.to(device)) else: pred, pred2, pred3 = model(tokens.to(device), masks.to(device)) loss_label = loss_criterion(pred, e_c.view(-1).to(device)).to(device) loss_label = torch.matmul(torch.gather(weight_label, 0, e_c.view(-1).to(device)), loss_label) / \ e_c.view(-1).shape[0] loss_binary = loss_criterion_binary(pred2, e_c_binary.view(-1).to(device)).to(device) loss_binary = torch.matmul(torch.gather(weight_binary, 0, e_c_binary.view(-1).to(device)), loss_binary) / e_c.view(-1).shape[0] loss_emo = loss_criterion_emo_only(pred3, e_c_emo.to(device)) loss = (loss_label + LAMBDA1 * loss_binary + LAMBDA2 * loss_emo) / float(1 + LAMBDA1 + LAMBDA2) dev_loss += loss.data.cpu().numpy() * tokens.shape[0] # pred_list.append(pred.data.cpu().numpy()) # gold_list.append(e_c.numpy()) del pred, loss # pred_list = np.argmax(np.concatenate(pred_list, axis=0), axis=1) # gold_list = np.concatenate(gold_list, axis=0) print('Training loss:', train_loss / len(train_data_set), end='\t') print('Dev loss:', dev_loss / len(dev_data_set)) # print(classification_report(gold_list, pred_list, target_names=EMOS)) # get_metrics(pred_list, gold_list) # checking diverge if dev_loss/len(dev_data_set) > 1.3 and num_epoch > 4: print("Model diverged, retry") is_diverged = True break if es.step(dev_loss): # overfitting print('overfitting, loading best model ...') if num_epoch == 1: is_diverged = True final_pred_best = deepcopy(final_pred_list_test) pred_list_test_best = deepcopy(pred_list_test) break else: if es.is_best(): print('saving best model ...') if final_pred_best is not None: del final_pred_best final_pred_best = deepcopy(final_pred_list_test) if pred_list_test_best is not None: del pred_list_test_best pred_list_test_best = deepcopy(pred_list_test) else: print('not best model, ignoring ...') if final_pred_best is None: final_pred_best = deepcopy(final_pred_list_test) if pred_list_test_best is None: pred_list_test_best = deepcopy(pred_list_test) print('Gold Dev ...') pred_list_test = [] model.eval() for i, (tokens, masks, segments, e_c, e_c_binary, e_c_emo) in enumerate(gold_dev_data_loader): with torch.no_grad(): if USE_TOKEN_TYPE: pred, _, _ = model(tokens.to(device), masks.to(device), segments.to(device)) else: pred, _, _ = model(tokens.to(device), masks.to(device)) pred_list_test.append(pred.data.cpu().numpy()) pred_list_test = np.argmax(np.concatenate(pred_list_test, axis=0), axis=1) # get_metrics(load_dev_labels('data/dev.txt'), pred_list_test) print('Gold Test ...') final_pred_list_test = [] model.eval() for i, (tokens, masks, segments, e_c, e_c_binary, e_c_emo) in enumerate(gold_test_data_loader): with torch.no_grad(): if USE_TOKEN_TYPE: pred, _, _ = model(tokens.to(device), masks.to(device), segments.to(device)) else: pred, _, _ = model(tokens.to(device), masks.to(device)) final_pred_list_test.append(pred.data.cpu().numpy()) final_pred_list_test = np.argmax(np.concatenate(final_pred_list_test, axis=0), axis=1) # get_metrics(load_dev_labels('data/test.txt'), final_pred_list_test) if is_diverged: print("Reinitialize model 
...") del model continue all_fold_results.append(pred_list_test_best) real_test_results.append(final_pred_best) del model break
def train(X_train, y_train, X_dev, y_dev, X_test, y_test):
    num_labels = NUM_EMO
    vocab_size = VOCAB_SIZE
    print('NUM of VOCAB: ' + str(vocab_size))

    train_data = EmotionDataLoader(X_train, y_train, PAD_LEN)
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

    dev_data = EmotionDataLoader(X_dev, y_dev, PAD_LEN)
    dev_loader = DataLoader(dev_data, batch_size=int(BATCH_SIZE / 3) + 2, shuffle=False)

    test_data = EmotionDataLoader(X_test, y_test, PAD_LEN)
    test_loader = DataLoader(test_data, batch_size=int(BATCH_SIZE / 3) + 2, shuffle=False)

    model = AttentionLSTMClassifier(EMBEDDING_DIM, HIDDEN_DIM, vocab_size, num_labels,
                                    BATCH_SIZE, att_mode=opt.attention, soft_last=False)
    model.load_embedding(tokenizer.get_embeddings())
    # multi-GPU
    # model = nn.DataParallel(model)
    model.cuda()

    loss_criterion = nn.CrossEntropyLoss()  #
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    es = EarlyStopping(patience=PATIENCE)
    old_model = None

    for epoch in range(1, 300):
        print('Epoch: ' + str(epoch) + '===================================')
        train_loss = 0
        model.train()
        for i, (data, seq_len, label) in tqdm(enumerate(train_loader),
                                              total=len(train_data) / BATCH_SIZE):
            optimizer.zero_grad()
            y_pred = model(data.cuda(), seq_len)
            loss = loss_criterion(y_pred, label.view(-1).cuda())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIPS)
            optimizer.step()
            train_loss += loss.data.cpu().numpy() * data.shape[0]
            del y_pred, loss

        test_loss = 0
        model.eval()
        for _, (_data, _seq_len, _label) in enumerate(dev_loader):
            with torch.no_grad():
                y_pred = model(_data.cuda(), _seq_len)
                loss = loss_criterion(y_pred, _label.view(-1).cuda())
                test_loss += loss.data.cpu().numpy() * _data.shape[0]
                del y_pred, loss

        print("Train Loss: " + str(train_loss / len(train_data)) +
              " Evaluation: " + str(test_loss / len(dev_data)))

        if es.step(test_loss):  # overfitting
            del model
            print('overfitting, loading best model ...')
            model = old_model
            break
        else:
            if es.is_best():
                if old_model is not None:
                    del old_model
                print('saving best model ...')
                old_model = deepcopy(model)
            else:
                print('not best model, ignoring ...')
                if old_model is None:
                    old_model = deepcopy(model)

    with open(f'lstm_{opt.dataset}_model.pt', 'bw') as f:
        torch.save(model.state_dict(), f)

    pred_list = []
    model.eval()
    for _, (_data, _seq_len, _label) in enumerate(test_loader):
        with torch.no_grad():
            y_pred = model(_data.cuda(), _seq_len)
            pred_list.append(y_pred.data.cpu().numpy())
            # x[np.where( x > 3.0 )]
            del y_pred

    pred_list = np.argmax(np.concatenate(pred_list, axis=0), axis=1)
    return pred_list
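# NOTE: all_fold_results / real_test_results above collect one prediction array per
# fold, but how they are combined is not shown in this excerpt. A common choice is a
# per-example majority vote across folds, sketched below (probability averaging before
# the argmax is another option); this is an assumption, not the original code.
def majority_vote(fold_preds):
    stacked = np.stack([np.asarray(p).astype(int) for p in fold_preds], axis=0)
    return np.apply_along_axis(lambda col: np.bincount(col).argmax(), 0, stacked)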
class Trainer(): def __init__(self, cfg, writer, img_writer, logger, run_id): # Copy shared config fields if "monodepth_options" in cfg: cfg["data"].update(cfg["monodepth_options"]) cfg["model"].update(cfg["monodepth_options"]) cfg["training"]["monodepth_loss"].update(cfg["monodepth_options"]) if "generated_depth_dir" in cfg["data"]: dataset_name = f"{cfg['data']['dataset']}_" \ f"{cfg['data']['width']}x{cfg['data']['height']}" depth_teacher = cfg["data"].get("depth_teacher", None) assert not (depth_teacher and cfg['model'].get('detph_estimator_weights') is not None) if depth_teacher is not None: cfg["data"]["generated_depth_dir"] += dataset_name + "/" + depth_teacher + "/" else: cfg["data"]["generated_depth_dir"] += dataset_name + "/" + cfg['model']['depth_estimator_weights'] + "/" # Setup seeds setup_seeds(cfg.get("seed", 1337)) if cfg["data"]["dataset_seed"] == "same": cfg["data"]["dataset_seed"] = cfg["seed"] # Setup device torch.backends.cudnn.benchmark = cfg["training"].get("benchmark", True) self.cfg = cfg self.writer = writer self.img_writer = img_writer self.logger = logger self.run_id = run_id self.mIoU = 0 self.fwAcc = 0 self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.setup_segmentation_unlabeled() self.unlabeled_require_depth = (self.cfg["training"]["unlabeled_segmentation"] is not None and (self.cfg["training"]["unlabeled_segmentation"]["mix_mask"] == "depth" or self.cfg["training"]["unlabeled_segmentation"]["mix_mask"] == "depthcomp" or self.cfg["training"]["unlabeled_segmentation"]["mix_mask"] == "depthhist")) # Prepare depth estimates do_precalculate_depth = self.cfg["training"]["segmentation_lambda"] != 0 and self.unlabeled_require_depth and \ self.cfg['model']['segmentation_name'] != 'mtl_pad' use_depth_teacher = cfg["data"].get("depth_teacher", None) is not None if do_precalculate_depth or use_depth_teacher: assert not (do_precalculate_depth and use_depth_teacher) if not self.cfg["training"].get("disable_depth_estimator", False): print("Prepare depth estimates") depth_estimator = DepthEstimator(cfg) depth_estimator.prepare_depth_estimates() del depth_estimator torch.cuda.empty_cache() else: self.cfg["data"]["generated_depth_dir"] = None # Setup Dataloader load_labels, load_sequence = True, True if self.cfg["training"]["monodepth_lambda"] == 0: load_sequence = False if self.cfg["training"]["segmentation_lambda"] == 0: load_labels = False train_data_cfg = deepcopy(self.cfg["data"]) if not do_precalculate_depth and not use_depth_teacher: train_data_cfg["generated_depth_dir"] = None self.train_loader = build_loader(train_data_cfg, "train", load_labels=load_labels, load_sequence=load_sequence) if self.cfg["training"].get("minimize_entropy_unlabeled", False) or self.enable_unlabled_segmentation: unlabeled_segmentation_cfg = deepcopy(self.cfg["data"]) if not self.only_unlabeled and self.mix_use_gt: unlabeled_segmentation_cfg["load_onehot"] = True if self.only_unlabeled: unlabeled_segmentation_cfg.update({"load_unlabeled": True, "load_labeled": False}) elif self.only_labeled: unlabeled_segmentation_cfg.update({"load_unlabeled": False, "load_labeled": True}) else: unlabeled_segmentation_cfg.update({"load_unlabeled": True, "load_labeled": True}) if self.mix_video: assert not self.mix_use_gt and not self.only_labeled and not self.only_unlabeled, \ "Video sample indices are not compatible with non-video indices." 
unlabeled_segmentation_cfg.update({"only_sequences_with_segmentation": not self.mix_video, "restrict_to_subset": None}) self.unlabeled_loader = build_loader(unlabeled_segmentation_cfg, "train", load_labels=load_labels if not self.mix_video else False, load_sequence=load_sequence) else: self.unlabeled_loader = None self.val_loader = build_loader(self.cfg["data"], "val", load_labels=load_labels, load_sequence=load_sequence) self.n_classes = self.train_loader.n_classes # monodepth dataloader settings uses drop_last=True and shuffle=True even for val self.train_data_loader = data.DataLoader( self.train_loader, batch_size=self.cfg["training"]["batch_size"], num_workers=self.cfg["training"]["n_workers"], shuffle=self.cfg["data"]["shuffle_trainset"], pin_memory=True, # Setting to false will cause crash at the end of epoch drop_last=True, ) if self.unlabeled_loader is not None: self.unlabeled_data_loader = infinite_iterator(data.DataLoader( self.unlabeled_loader, batch_size=self.cfg["training"]["batch_size"], num_workers=self.cfg["training"]["n_workers"], shuffle=self.cfg["data"]["shuffle_trainset"], pin_memory=True, # Setting to false will cause crash at the end of epoch drop_last=True, )) self.val_batch_size = self.cfg["training"]["val_batch_size"] self.val_data_loader = data.DataLoader( self.val_loader, batch_size=self.val_batch_size, num_workers=self.cfg["training"]["n_workers"], pin_memory=True, # If using a dataset with odd number of samples (CamVid), the memory consumption suddenly increases for the # last batch. This can be circumvented by dropping the last batch. Only do that if it is necessary for your # system as it will result in an incomplete validation set. # drop_last=True, ) # Setup Model self.model = get_model(cfg["model"], self.n_classes).to(self.device) # print(self.model) assert not (self.enable_unlabled_segmentation and self.cfg["training"]["save_monodepth_ema"]) if self.enable_unlabled_segmentation and not self.only_labeled: print("Create segmentation ema model.") self.ema_model = self.create_ema_model(self.model).to(self.device) elif self.cfg["training"]["save_monodepth_ema"]: print("Create depth ema model.") # TODO: Try to remove unnecessary components and fit into gpu for better performance self.ema_model = self.create_ema_model(self.model) # .to(self.device) else: self.ema_model = None # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = {k: v for k, v in cfg["training"]["optimizer"].items() if k not in ["name", "backbone_lr", "pose_lr", "depth_lr", "segmentation_lr"]} train_params = get_train_params(self.model, self.cfg) self.optimizer = optimizer_cls(train_params, **optimizer_params) self.scheduler = get_scheduler(self.optimizer, self.cfg["training"]["lr_schedule"]) # Creates a GradScaler once at the beginning of training. 
self.scaler = GradScaler(enabled=self.cfg["training"]["amp"]) self.loss_fn = get_segmentation_loss_function(self.cfg) self.monodepth_loss_calculator_train = get_monodepth_loss(self.cfg, is_train=True) self.monodepth_loss_calculator_val = get_monodepth_loss(self.cfg, is_train=False, batch_size=self.val_batch_size) if cfg["training"]["early_stopping"] is None: logger.info("Using No Early Stopping") self.earlyStopping = None else: self.earlyStopping = EarlyStopping( patience=round(cfg["training"]["early_stopping"]["patience"] / cfg["training"]["val_interval"]), min_delta=cfg["training"]["early_stopping"]["min_delta"], cumulative_delta=cfg["training"]["early_stopping"]["cum_delta"], logger=logger ) def extract_monodepth_ema_params(self, model, ema_model): model_names = ["depth"] if not self.cfg["model"]["freeze_backbone"]: model_names.append("encoder") return extract_ema_params(model, ema_model, model_names) def extract_pad_ema_params(self, model, ema_model): model_names = ["depth", "encoder", "mtl_decoder"] return extract_ema_params(model, ema_model, model_names) def create_ema_model(self, model): ema_cfg = deepcopy(self.cfg["model"]) ema_cfg["disable_pose"] = True ema_model = get_model(ema_cfg, self.n_classes) if self.cfg["training"]["save_monodepth_ema"]: mp, mcp = self.extract_monodepth_ema_params(model, ema_model) elif self.cfg['model']['segmentation_name'] == 'mtl_pad': mp, mcp = self.extract_pad_ema_params(model, ema_model) else: mp, mcp = list(model.parameters()), list(ema_model.parameters()) for param in mcp: param.detach_() assert len(mp) == len(mcp), f"len(mp)={len(mp)}; len(mcp)={len(mcp)}" n = len(mp) for i in range(0, n): mcp[i].data[:] = mp[i].to(mcp[i].device, non_blocking=True).data[:].clone() return ema_model def update_ema_variables(self, ema_model, model, alpha_teacher, iteration): if self.cfg["training"]["save_monodepth_ema"]: model_params, ema_params = self.extract_monodepth_ema_params(model, ema_model) elif self.cfg['model']['segmentation_name'] == 'mtl_pad': model_params, ema_params = self.extract_pad_ema_params(model, ema_model) else: model_params, ema_params = model.parameters(), ema_model.parameters() # Use the "true" average until the exponential average is more correct alpha_teacher = min(1 - 1 / (iteration + 1), alpha_teacher) for ema_param, param in zip(ema_params, model_params): ema_param.data[:] = alpha_teacher * ema_param[:].data[:] + \ (1 - alpha_teacher) * param.to(ema_param.device, non_blocking=True)[:].data[:] return ema_model def save_resume(self, step): if self.ema_model is not None: raise NotImplementedError("ema model not supported") state = { "epoch": step + 1, "model_state": self.model.state_dict(), "optimizer_state": self.optimizer.state_dict(), "scheduler_state": self.scheduler.state_dict(), "best_iou": self.best_iou, } save_path = os.path.join( self.writer.file_writer.get_logdir(), "best_model.pkl" ) torch.save(state, save_path) return save_path def save_monodepth_models(self): if self.cfg["training"]["save_monodepth_ema"]: print("Save ema monodepth models.") assert self.ema_model is not None model_to_save = self.ema_model else: model_to_save = self.model models = ["depth", "pose_encoder", "pose"] if not self.cfg["model"]["freeze_backbone"]: models.append("encoder") for model_name in models: save_path = os.path.join(self.writer.file_writer.get_logdir(), "{}.pth".format(model_name)) to_save = model_to_save.models[model_name].state_dict() torch.save(to_save, save_path) def load_resume(self, strict=True, load_model_only=False): if 
os.path.isfile(self.cfg["training"]["resume"]): self.logger.info( "Loading model and optimizer from checkpoint '{}'".format(self.cfg["training"]["resume"]) ) checkpoint = torch.load(self.cfg["training"]["resume"]) self.model.load_state_dict(checkpoint["model_state"], strict=strict) if not load_model_only: self.optimizer.load_state_dict(checkpoint["optimizer_state"]) self.scheduler.load_state_dict(checkpoint["scheduler_state"]) self.start_iter = checkpoint["epoch"] self.best_iou = checkpoint["best_iou"] self.logger.info( "Loaded checkpoint '{}' (iter {})".format( self.cfg["training"]["resume"], checkpoint["epoch"] ) ) else: self.logger.info("No checkpoint found at '{}'".format(self.cfg["training"]["resume"])) def tensorboard_training_images(self): num_saved = 0 if self.cfg["training"]["n_tensorboard_trainimgs"] == 0: return for inputs in self.train_data_loader: images = inputs[("color_aug", 0, 0)] labels = inputs["lbl"] for img, label in zip(images.numpy(), labels.numpy()): if num_saved < self.cfg["training"]["n_tensorboard_trainimgs"]: num_saved += 1 self.img_writer.add_image( "trainset_{}/{}_0image".format(self.run_id.replace('/', '_'), num_saved), img, global_step=0) colored_image = self.val_loader.decode_segmap_tocolor(label) self.img_writer.add_image( "trainset_{}/{}_1ground_truth".format(self.run_id.replace('/', '_'), num_saved), colored_image, global_step=0, dataformats="HWC") if num_saved >= self.cfg["training"]["n_tensorboard_trainimgs"]: break def _train_batchnorm(self, model, train, only_encoder=False): if only_encoder: modules = model.models["encoder"].modules() else: modules = model.modules() for m in modules: if isinstance(m, nn.BatchNorm2d): m.train(train) def train_step(self, inputs, step): self.model.train() if self.ema_model is not None: self.ema_model.train() for k, v in inputs.items(): if torch.is_tensor(v): inputs[k] = v.to(self.device, non_blocking=True) if self.enable_unlabled_segmentation: unlabeled_inputs = self.unlabeled_data_loader.__next__() for k in unlabeled_inputs.keys(): if "color_aug" in k or "K" in k or "inv_K" in k or "color" in k or k in ["onehot_lbl", "pseudo_depth"]: # print(f"Move {k} to gpu.") unlabeled_inputs[k] = unlabeled_inputs[k].to(self.device, non_blocking=True) self.optimizer.zero_grad() segmentation_loss = torch.tensor(0) segmentation_total_loss = torch.tensor(0) mono_loss = torch.tensor(0) feat_dist_loss = torch.tensor(0) mono_total_loss = torch.tensor(0) if self.cfg["model"].get("freeze_backbone_bn", False): self._train_batchnorm(self.model, False, only_encoder=True) with autocast(enabled=self.cfg["training"]["amp"]): outputs = self.model(inputs) # Train monodepth if self.cfg["training"]["monodepth_lambda"] > 0: for k, v in outputs.items(): if "depth" in k or "cam_T_cam" in k: outputs[k] = v.to(torch.float32) self.monodepth_loss_calculator_train.generate_images_pred(inputs, outputs) mono_losses = self.monodepth_loss_calculator_train.compute_losses(inputs, outputs) mono_lambda = self.cfg["training"]["monodepth_lambda"] mono_loss = mono_lambda * mono_losses["loss"] feat_dist_lambda = self.cfg["training"]["feat_dist_lambda"] if feat_dist_lambda > 0: feat_dist = torch.dist(outputs["encoder_features"], outputs["imnet_features"], p=2) feat_dist_loss = feat_dist_lambda * feat_dist mono_total_loss = mono_loss + feat_dist_loss self.scaler.scale(mono_total_loss).backward(retain_graph=True) # Train depth on pseudo-labels if self.cfg["training"].get("pseudo_depth_lambda", 0) > 0: # Crop away bottom of image with own car with torch.no_grad(): 
depth_loss_mask = torch.ones(outputs["disp", 0].shape, device=self.device) depth_loss_mask[:, :, int(outputs["disp", 0].shape[2] * 0.9):, :] = 0 pseudo_depth_loss = berhu(outputs["disp", 0], inputs["pseudo_depth"], depth_loss_mask) pseudo_depth_loss *= self.cfg["training"]["pseudo_depth_lambda"] self.scaler.scale(pseudo_depth_loss).backward(retain_graph=True) else: pseudo_depth_loss = torch.tensor(0) # Train segmentation if self.cfg["training"]["segmentation_lambda"] > 0: with autocast(enabled=self.cfg["training"]["amp"]): segmentation_loss = self.loss_fn(input=outputs["semantics"], target=inputs["lbl"]) if "intermediate_semantics" in outputs: segmentation_loss += self.loss_fn(input=outputs["intermediate_semantics"], target=inputs["lbl"]) segmentation_loss /= 2 segmentation_loss *= self.cfg["training"]["segmentation_lambda"] segmentation_total_loss = segmentation_loss self.scaler.scale(segmentation_total_loss).backward() if self.enable_unlabled_segmentation: unlabeled_loss, unlabeled_mono_loss = self.train_step_segmentation_unlabeled(unlabeled_inputs, step) segmentation_total_loss += unlabeled_loss mono_total_loss += unlabeled_mono_loss if self.cfg["training"].get("clip_grad_norm") is not None: # Unscales the gradients of optimizer's assigned params in-place self.scaler.unscale_(self.optimizer) # Since the gradients of optimizer's assigned params are unscaled, clips as usual: if self.cfg["training"].get("disable_depth_grad_clip", False): torch.nn.utils.clip_grad_norm_(get_params(self.model, ["encoder", "segmentation"]), self.cfg["training"]["clip_grad_norm"]) else: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg["training"]["clip_grad_norm"]) # optimizer's gradients are already unscaled, so scaler.step does not unscale them, # although it still skips optimizer.step() if the gradients contain infs or NaNs. 
self.scaler.step(self.optimizer) self.scaler.update() if isinstance(self.scheduler, ReduceLROnPlateau): self.scheduler.step(metrics=self.mIoU) else: self.scheduler.step() # update Mean teacher network if self.ema_model is not None: self.ema_model = self.update_ema_variables(ema_model=self.ema_model, model=self.model, alpha_teacher=0.99, iteration=step) total_loss = segmentation_total_loss + mono_total_loss + pseudo_depth_loss return { 'segmentation_loss': segmentation_loss.detach(), 'mono_loss': mono_loss.detach(), 'pseudo_depth_loss': pseudo_depth_loss.detach(), 'feat_dist_loss': feat_dist_loss.detach(), 'segmentation_total_loss': segmentation_total_loss.detach(), 'mono_total_loss': mono_total_loss.detach(), 'total_loss': total_loss.detach() } def setup_segmentation_unlabeled(self): if self.cfg["training"].get("unlabeled_segmentation", None) is None: self.enable_unlabled_segmentation = False return unlabeled_cfg = self.cfg["training"]["unlabeled_segmentation"] self.enable_unlabled_segmentation = True self.consistency_weight = unlabeled_cfg["consistency_weight"] self.mix_mask = unlabeled_cfg.get("mix_mask", None) self.unlabeled_color_jitter = unlabeled_cfg.get("color_jitter") self.unlabeled_blur = unlabeled_cfg.get("blur") self.only_unlabeled = unlabeled_cfg.get("only_unlabeled", True) self.only_labeled = unlabeled_cfg.get("only_labeled", False) self.mix_video = unlabeled_cfg.get("mix_video", False) assert not (self.only_unlabeled and self.only_labeled) self.mix_use_gt = unlabeled_cfg.get("mix_use_gt", False) self.unlabeled_debug_imgs = unlabeled_cfg.get("debug_images", False) self.depthcomp_margin = unlabeled_cfg["depthcomp_margin"] self.depthcomp_foreground_threshold = unlabeled_cfg["depthcomp_foreground_threshold"] self.unlabeled_backward_first_pseudo_label = unlabeled_cfg["backward_first_pseudo_label"] self.depthmix_online_depth = unlabeled_cfg.get("depthmix_online_depth", False) def generate_mix_mask(self, mode, argmax_u_w, unlabeled_imgs, depths): if mode == "class": for image_i in range(self.cfg["training"]["batch_size"]): classes = torch.unique(argmax_u_w[image_i]) classes = classes[classes != 250] nclasses = classes.shape[0] classes = (classes[torch.Tensor( np.random.choice(nclasses, int((nclasses - nclasses % 2) / 2), replace=False)).long()]).cuda() if image_i == 0: MixMask = transformmasks.generate_class_mask(argmax_u_w[image_i], classes).unsqueeze(0).cuda() else: MixMask = torch.cat( (MixMask, transformmasks.generate_class_mask(argmax_u_w[image_i], classes).unsqueeze(0).cuda())) elif self.mix_mask == "depthcomp": assert self.cfg["training"]["batch_size"] == 2 for image_i, other_image_i in [(0, 1), (1, 0)]: own_disp = depths[image_i] other_disp = depths[other_image_i] # Margin avoids too much of mixing road with same depth foreground_mask = torch.ge(own_disp, other_disp - self.depthcomp_margin).long() # Avoid hiding the real background of the other image with own a bit closer background if isinstance(self.depthcomp_foreground_threshold, tuple) or isinstance( self.depthcomp_foreground_threshold, list): ft_l, ft_u = self.depthcomp_foreground_threshold assert ft_u > ft_l ft = torch.rand(1, device=own_disp.device) * (ft_u - ft_l) + ft_l else: ft = self.depthcomp_foreground_threshold foreground_mask *= torch.ge(own_disp, ft).long() if image_i == 0: MixMask = foreground_mask else: MixMask = torch.cat((MixMask, foreground_mask)) elif mode == "depth": for image_i in range(self.cfg["training"]["batch_size"]): generated_depth = depths[image_i] min_depth = 0.1 max_depth = 0.4 
depth_threshold = torch.rand(1, device=depths.device) * (max_depth - min_depth) + min_depth if image_i == 0: MixMask = transformmasks.generate_depth_mask(generated_depth, depth_threshold).cuda() else: MixMask = torch.cat( (MixMask, transformmasks.generate_depth_mask(generated_depth, depth_threshold).cuda())) elif mode == "depthhist": for image_i in range(self.cfg["training"]["batch_size"]): generated_depth = depths[image_i] hist, bin_edges = np.histogram(torch.log(1 + generated_depth).flatten(), bins=100, density=True) # Exclude the first bin as it sometimes has a meaningless peak for v, e in zip(np.flip(hist)[1:], np.flip(bin_edges)[1:]): if v > 1.5: max_depth = torch.tensor([e]) break hist = np.cumsum(hist) / np.sum(hist) for v, e in zip(hist, bin_edges): if v > 0.4: min_depth = torch.tensor([e]) break depth_threshold = torch.rand(1) * (max_depth - min_depth) + min_depth if image_i == 0: MixMask = transformmasks.generate_depth_mask(generated_depth, depth_threshold).cuda() else: MixMask = torch.cat( (MixMask, transformmasks.generate_depth_mask(generated_depth, depth_threshold).cuda())) elif mode is None: MixMask = torch.ones((unlabeled_imgs.shape[0], *unlabeled_imgs.shape[2:]), device=self.device) else: raise NotImplementedError(f"Unknown mix_mask {self.mix_mask}") return MixMask def calc_pseudo_label_loss(self, teacher_softmax, student_logits): max_probs, pseudo_label = torch.max(teacher_softmax, dim=1) pseudo_label[max_probs == 0] = self.unlabeled_loader.ignore_index unlabeled_weight = torch.sum(max_probs.ge(0.968).long() == 1).item() / np.prod(pseudo_label.shape) pixelWiseWeight = unlabeled_weight * torch.ones(max_probs.shape, device=self.device) L_u = self.consistency_weight * cross_entropy2d(input=student_logits, target=pseudo_label, pixel_weights=pixelWiseWeight) return L_u, pseudo_label def train_step_segmentation_unlabeled(self, unlabeled_inputs, step): def strongTransform(parameters, data=None, target=None): assert ((data is not None) or (target is not None)) data, target = transformsgpu.mix(mask=parameters["Mix"], data=data, target=target) data, target = transformsgpu.color_jitter(jitter=parameters["ColorJitter"], data=data, target=target) data, target = transformsgpu.gaussian_blur(blur=parameters["GaussianBlur"], data=data, target=None) return data, target unlabeled_imgs = unlabeled_inputs[("color_aug", 0, 0)] # First Step: Run teacher to generate pseudo labels self.ema_model.use_pose_net = False logits_u_w = self.ema_model(unlabeled_inputs)["semantics"] softmax_u_w = torch.softmax(logits_u_w.detach(), dim=1) if self.mix_use_gt: with torch.no_grad(): for i in range(unlabeled_imgs.shape[0]): # .data is necessary to access truth value of tensor if unlabeled_inputs["is_labeled"][i].data: softmax_u_w[i] = unlabeled_inputs["onehot_lbl"][i] _, argmax_u_w = torch.max(softmax_u_w, dim=1) # Second Step: Run student network on unaugmented data to generate depth for DepthMix, calculate monodepth loss, # and unaugmented segmentation pseudo label loss mono_loss = 0 L_1 = 0 if self.depthmix_online_depth: outputs_1 = self.model(unlabeled_inputs) if self.cfg["training"]["monodepth_lambda"] > 0: self.monodepth_loss_calculator_train.generate_images_pred(unlabeled_inputs, outputs_1) mono_losses = self.monodepth_loss_calculator_train.compute_losses(unlabeled_inputs, outputs_1) mono_lambda = self.cfg["training"]["monodepth_lambda"] mono_loss = mono_lambda * mono_losses["loss"] self.scaler.scale(mono_loss).backward(retain_graph=self.unlabeled_backward_first_pseudo_label) depths = outputs_1[("disp", 
0)].detach() for j in range(depths.shape[0]): dmin = torch.min(depths[j]) dmax = torch.max(depths[j]) depths[j] = torch.clamp(depths[j], dmin, dmax) depths[j] = (depths[j] - dmin) / (dmax - dmin) else: depths = unlabeled_inputs["pseudo_depth"] if self.unlabeled_backward_first_pseudo_label: logits_1 = outputs_1["semantics"] L_1, _ = self.calc_pseudo_label_loss(teacher_softmax=softmax_u_w, student_logits=logits_1) self.scaler.scale(L_1).backward() elif "pseudo_depth" in unlabeled_inputs: depths = unlabeled_inputs["pseudo_depth"] else: depths = [None] * unlabeled_imgs.shape[0] # Third Step: Run Mix MixMask = self.generate_mix_mask(self.mix_mask, argmax_u_w, unlabeled_imgs, depths) strong_parameters = {"Mix": MixMask} if self.unlabeled_color_jitter: strong_parameters["ColorJitter"] = random.uniform(0, 1) else: strong_parameters["ColorJitter"] = 0 if self.unlabeled_blur: strong_parameters["GaussianBlur"] = random.uniform(0, 1) else: strong_parameters["GaussianBlur"] = 0 inputs_u_s, _ = strongTransform(strong_parameters, data=unlabeled_imgs) unlabeled_inputs[("color_aug", 0, 0)] = inputs_u_s outputs = self.model(unlabeled_inputs) logits_u_s = outputs["semantics"] softmax_u_w_mixed, _ = strongTransform(strong_parameters, data=softmax_u_w) L_2, pseudo_label = self.calc_pseudo_label_loss(teacher_softmax=softmax_u_w_mixed, student_logits=logits_u_s) self.scaler.scale(L_2).backward() for j, (f, img, ps_lab, mask, d) in enumerate( zip(unlabeled_inputs["filename"], inputs_u_s, pseudo_label, MixMask, depths)): if (step + 1) % self.cfg["training"]["print_interval"] != 0: continue fn = f"{self.cfg['training']['log_path']}/class_mix_debug/{step}_{j}_img.jpg" os.makedirs(os.path.dirname(fn), exist_ok=True) rows, cols = 2, 2 fig, axs = plt.subplots(rows, cols, sharex='col', sharey='row', gridspec_kw={'hspace': 0, 'wspace': 0}, figsize=(4 * cols, 4 * rows)) axs[0][0].imshow(img.permute(1, 2, 0).cpu().numpy()) axs[0][1].imshow(mask.float().cpu().numpy(), cmap="gray") if d is not None: axs[1][1].imshow(d[0].cpu().numpy(), cmap="plasma") axs[1][0].imshow(self.val_loader.decode_segmap_tocolor(ps_lab.cpu().numpy())) for ax in axs.flat: ax.axis("off") plt.savefig(fn) plt.close() return L_2 + L_1, mono_loss def train(self): self.start_iter = 0 self.best_iou = -100.0 if self.cfg["training"]["resume"] is not None: self.load_resume() for param_group in self.optimizer.param_groups: param_group['lr'] = self.cfg["training"]["optimizer"]["lr"] train_loss_meter = AverageMeterDict() time_meter = AverageMeter() step = self.start_iter flag = True self.tensorboard_training_images() start_ts = time.time() while step <= self.cfg["training"]["train_iters"] and flag: for inputs in self.train_data_loader: # torch.cuda.empty_cache() step += 1 losses = self.train_step(inputs, step) time_meter.update(time.time() - start_ts) train_loss_meter.update(losses) if (step + 1) % self.cfg["training"]["print_interval"] == 0: fmt_str = "Iter [{}/{}] Loss: {:.4f} Time/Image: {:.4f}" print_str = fmt_str.format( step + 1, self.cfg["training"]["train_iters"], train_loss_meter.avgs["total_loss"], time_meter.avg / self.cfg["training"]["batch_size"], ) self.logger.info(print_str) for k, v in train_loss_meter.avgs.items(): self.writer.add_scalar("training/" + k, v, step + 1) self.writer.add_scalar("training/learning_rate", get_lr(self.optimizer), step + 1) self.writer.add_scalar("training/time_per_image", time_meter.avg / self.cfg["training"]["batch_size"], step + 1) self.writer.add_scalar("training/amp_scale", self.scaler.get_scale(), step + 1) 
self.writer.add_scalar("training/memory", psutil.virtual_memory().used / 1e9, step + 1) time_meter.reset() train_loss_meter.reset() if (step + 1) % current_val_interval(self.cfg, step + 1) == 0 or (step + 1) == self.cfg["training"][ "train_iters" ]: self.validate(step) if self.mIoU >= self.best_iou: self.best_iou = self.mIoU if self.cfg["training"]["save_model"]: self.save_resume(step) if self.earlyStopping is not None: if not self.earlyStopping.step(self.mIoU): flag = False break if (step + 1) == self.cfg["training"]["train_iters"]: flag = False break start_ts = time.time() return step def validate(self, step): self.model.eval() val_loss_meter = AverageMeterDict() running_metrics_val = runningScore(self.n_classes) imgs_to_save = [] with torch.no_grad(): for inputs_val in tqdm(self.val_data_loader, total=len(self.val_data_loader)): if self.cfg["model"]["disable_monodepth"]: required_inputs = [("color_aug", 0, 0), "lbl"] else: required_inputs = inputs_val.keys() for k, v in inputs_val.items(): if torch.is_tensor(v) and k in required_inputs: inputs_val[k] = v.to(self.device, non_blocking=True) images_val = inputs_val[("color_aug", 0, 0)] with autocast(enabled=self.cfg["training"]["amp"]): outputs = self.model(inputs_val) if self.cfg["training"]["segmentation_lambda"] > 0: labels_val = inputs_val["lbl"] semantics = outputs["semantics"] val_segmentation_loss = self.loss_fn(input=semantics, target=labels_val) # Handle inconsistent size between input and target n, c, h, w = semantics.size() nt, ht, wt = labels_val.size() if h != ht and w != wt: # upsample labels semantics = F.interpolate( semantics, size=(ht, wt), mode="bilinear", align_corners=True ) pred = semantics.data.max(1)[1].cpu().numpy() gt = labels_val.data.cpu().numpy() running_metrics_val.update(gt, pred) else: pred = [None] * images_val.shape[0] gt = [None] * images_val.shape[0] val_segmentation_loss = torch.tensor(0) if not self.cfg["model"]["disable_monodepth"]: if not self.cfg["model"]["disable_pose"]: self.monodepth_loss_calculator_val.generate_images_pred(inputs_val, outputs) mono_losses = self.monodepth_loss_calculator_val.compute_losses(inputs_val, outputs) val_mono_loss = mono_losses["loss"] else: outputs.update(self.model.predict_test_disp(inputs_val)) self.monodepth_loss_calculator_val.generate_depth_test_pred(outputs) val_mono_loss = torch.tensor(0) else: outputs[("disp", 0)] = [None] * images_val.shape[0] val_mono_loss = torch.tensor(0) if self.cfg["data"].get("depth_teacher", None) is not None: # Crop away bottom of image with own car with torch.no_grad(): depth_loss_mask = torch.ones(outputs["disp", 0].shape, device=self.device) depth_loss_mask[:, :, int(outputs["disp", 0].shape[2] * 0.9):, :] = 0 val_pseudo_depth_loss = berhu(outputs["disp", 0], inputs_val["pseudo_depth"], depth_loss_mask, apply_log=self.cfg["training"].get("pseudo_depth_loss_log", False)) else: val_pseudo_depth_loss = torch.tensor(0) val_loss_meter.update({ "segmentation_loss": val_segmentation_loss.detach(), "monodepth_loss": val_mono_loss.detach(), "pseudo_depth_loss": val_pseudo_depth_loss.detach() }) for img, label, output, depth in zip(images_val, gt, pred, outputs[("disp", 0)]): if len(imgs_to_save) < self.cfg["training"]["n_tensorboard_imgs"]: imgs_to_save.append([ img, label, output, depth if depth is None else depth.detach()]) for k, v in val_loss_meter.avgs.items(): self.writer.add_scalar("validation/" + k, v, step + 1) if self.cfg["training"]["segmentation_lambda"] > 0: score, class_iou = running_metrics_val.get_scores() for k, v in 
score.items(): print(k, v) self.writer.add_scalar("val_metrics/{}".format(k), v, step + 1) for k, v in class_iou.items(): self.writer.add_scalar("val_metrics/cls_{}".format(k), v, step + 1) self.mIoU = score["Mean IoU : \t"] self.fwAcc = score["FreqW Acc : \t"] for j, imgs in enumerate(imgs_to_save): # Only log the first image as they won't change -> save memory if (step + 1) // current_val_interval(self.cfg, step + 1) == 1: self.img_writer.add_image( "{}/{}_0image".format(self.run_id.replace('/', '_'), j), imgs[0], global_step=step + 1) if imgs[1] is not None: colored_image = self.val_loader.decode_segmap_tocolor(imgs[1]) self.img_writer.add_image( "{}/{}_1ground_truth".format(self.run_id.replace('/', '_'), j), colored_image, global_step=step + 1, dataformats="HWC") if imgs[2] is not None: colored_image = self.val_loader.decode_segmap_tocolor(imgs[2]) self.img_writer.add_image( "{}/{}_2prediction".format(self.run_id.replace('/', '_'), j), colored_image, global_step=step + 1, dataformats="HWC") if imgs[3] is not None: colored_image = _colorize(imgs[3], "plasma", max_percentile=100) self.img_writer.add_image( "{}/{}_3depth".format(self.run_id.replace('/', '_'), j), colored_image, global_step=step + 1, dataformats="HWC")
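# The pseudo-depth supervision in the validation loop above calls
# berhu(pred_disp, pseudo_depth, mask, apply_log=...), whose definition is outside this
# excerpt. Below is a minimal sketch of a reverse Huber (berHu) loss with that call
# signature; the masking behaviour, the log-space option, and the 0.2 * max threshold
# are assumptions and may differ from the project's actual helper.
import torch


def berhu(pred, target, mask=None, apply_log=False, c_factor=0.2):
    # Optionally compare depths in log space (assumed meaning of apply_log).
    if apply_log:
        pred = torch.log(pred.clamp(min=1e-6))
        target = torch.log(target.clamp(min=1e-6))
    diff = (pred - target).abs()
    if mask is not None:
        diff = diff * mask
    # berHu: L1 below the threshold c, quadratic above it (Laina et al., 2016).
    c = c_factor * diff.max().clamp(min=1e-6)
    loss = torch.where(diff <= c, diff, (diff ** 2 + c ** 2) / (2 * c))
    if mask is not None:
        return loss.sum() / mask.sum().clamp(min=1.0)
    return loss.mean()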
class PyTorchTrainer:
    def __init__(self, model, device, config, fold_num):
        self.config = config
        self.epoch = 0
        self.start_epoch = 0
        self.fold_num = fold_num
        if self.config.stage2:
            self.base_dir = f'./result/stage2/{config.dir}/{config.dir}_fold_{config.fold_num}'
        else:
            self.base_dir = f'./result/{config.dir}/{config.dir}_fold_{config.fold_num}'
        os.makedirs(self.base_dir, exist_ok=True)
        self.log_path = f'{self.base_dir}/log.txt'
        self.best_summary_loss = 10 ** 5
        self.model = model
        self.swa_model = AveragedModel(self.model)
        self.device = device
        self.wandb = True
        self.cutmix = self.config.cutmix_ratio
        self.fmix = self.config.fmix_ratio
        self.smix = self.config.smix_ratio
        self.es = EarlyStopping(patience=8)
        self.scaler = GradScaler()
        self.amp = self.config.amp

        # Exclude biases and LayerNorm weights from weight decay.
        # NOTE: these grouped parameters are prepared here, but get_optimizer below
        # builds the optimizer from the model directly.
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.001
        }, {
            'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]
        self.optimizer, self.scheduler = get_optimizer(
            self.model, self.config.optimizer_name, self.config.optimizer_params,
            self.config.scheduler_name, self.config.scheduler_params,
            self.config.n_epochs)
        self.criterion = get_criterion(self.config.criterion_name, self.config.criterion_params)
        self.log(f'Fitter prepared. Device is {self.device}')
        set_wandb(self.config, fold_num)

    def fit(self, train_loader, validation_loader):
        if self.config.FIRST_FREEZE:
            self.model.freeze()
        for e in range(self.start_epoch, self.config.n_epochs):
            if self.config.verbose:
                lr = self.optimizer.param_groups[0]['lr']
                timestamp = datetime.utcnow().isoformat()
                self.log(f'\n{timestamp}\nLR: {lr}')
                wandb.log({"Epoch": self.epoch, "lr": lr}, step=e)

            if self.config.step_scheduler:
                self.scheduler.step(e)

            # Freeze the backbone for a range of epochs, then unfreeze it again.
            if e >= self.config.START_FREEZE and self.config.FREEZE:
                print('Model frozen -> training classifier head only')
                self.model.freeze()
                self.config.FREEZE = False
            if e >= self.config.END_FREEZE and self.config.FIRST_FREEZE:
                print('Model unfrozen -> training full model')
                self.model.unfreeze()
                self.config.FIRST_FREEZE = False

            t = time.time()
            summary_loss, summary_scores, example_images = self.train_one_epoch(train_loader)
            torch.cuda.empty_cache()
            self.log(
                f'[RESULT]: Train. Epoch: {self.epoch}, Fold Num: {self.fold_num}, summary_loss: {summary_loss.avg:.5f}, summary_acc: {summary_scores.avg}, time: {(time.time() - t):.5f}'
            )
            self.save(f'{self.base_dir}/{self.config.dir}_fold_{self.fold_num}_last-checkpoint.bin')
            wandb.log(
                {
                    "Train_loss": summary_loss.avg,
                    "Train_ACC": summary_scores.avg,
                    f"Example_{self.config.fold_num}": example_images
                },
                step=e)

            t = time.time()
            summary_loss, summary_scores = self.validation(validation_loader)
            torch.cuda.empty_cache()
            self.log(
                f'[RESULT]: Val. Epoch: {self.epoch}, summary_loss: {summary_loss.avg:.5f}, summary_acc: {summary_scores.avg}, time: {(time.time() - t):.5f}'
            )

            # Checkpoint after every validation pass (the best-loss condition is disabled).
            # if summary_loss.avg < self.best_summary_loss:
            self.best_summary_loss = summary_loss.avg
            self.model.eval()
            self.save(
                f'{self.base_dir}/{self.config.dir}_fold_{self.config.fold_num}_best-checkpoint-{str(self.epoch).zfill(3)}epoch.bin'
            )
            # for path in sorted(glob(f'{self.base_dir}/{self.config.dir}_fold_{self.config.fold_num}_best-checkpoint-*epoch.bin'))[:-3]:
            #     os.remove(path)

            if self.config.validation_scheduler:
                self.scheduler.step(metrics=summary_loss.avg)
            self.epoch += 1

    def validation(self, val_loader):
        self.model.eval()
        summary_loss = AverageMeter()
        summary_acc = AverageMeter()
        t = time.time()
        y_true = []
        y_pred = []
        for step, (images, targets) in enumerate(val_loader):
            if self.config.verbose:
                if step % self.config.verbose_step == 0:
                    print(
                        f'Val Step {step}/{len(val_loader)}, '
                        f'summary_loss: {summary_loss.avg:.5f}, '
                        f'time: {(time.time() - t):.5f}', end='\r')
            with torch.no_grad():
                targets = targets.to(self.device).float()
                batch_size = images.shape[0]
                images = images.to(self.device).float()
                _, outputs = self.model(images)
                loss = self.criterion(outputs, targets)
                # targets = targets.argmax(1)
                y_true.extend(targets.argmax(1).detach().cpu().numpy())
                y_pred.extend(outputs.argmax(1).detach().cpu().numpy())
                summary_loss.update(loss.detach().item(), batch_size)
                summary_acc.update(
                    (outputs.argmax(1) == targets.argmax(1)).sum().item() / batch_size,
                    batch_size)

        wandb.log({"Val_loss": summary_loss.avg, "Val_ACC": summary_acc.avg}, step=self.epoch)

        # Stop the whole run once the validation loss has stopped improving.
        if self.es.step(torch.tensor(summary_loss.avg)):
            self.log("Early stopping triggered")
            plot_confusion_matrix(y_true, y_pred)
            exit(0)
        if self.epoch == self.config.n_epochs - 1:
            plot_confusion_matrix(y_true, y_pred)
        return summary_loss, summary_acc

    def train_one_epoch(self, train_loader):
        self.model.train()
        if self.epoch < self.config.freeze_bn_epoch:
            self.model.freeze_batchnorm_stats()
        summary_loss = AverageMeter()
        summary_acc = AverageMeter()
        example_images = []
        t = time.time()
        for step, (images, targets) in enumerate(train_loader):
            # Random draw deciding which mix augmentation (if any) is applied to this batch.
            choice = np.random.rand(1)
            self.optimizer.zero_grad()
            if self.config.verbose:
                if step % self.config.verbose_step == 0:
                    print(
                        f'Train Step {step}/{len(train_loader)}, '
                        f'summary_loss: {summary_loss.avg:.5f}, '
                        f'time: {(time.time() - t):.5f}', end='\r')
            targets = targets.to(self.device).float()
            images = images.to(self.device).float()
            batch_size = images.shape[0]

            if self.config.FIRST_FREEZE and self.config.END_FREEZE > self.epoch:
                # Warm-up phase with a frozen backbone: plain loss, no mix augmentations.
                if self.amp:
                    with autocast():
                        _, outputs = self.model(images)
                        loss = self.criterion(outputs, targets)
                    self.scaler.scale(loss).backward()
                    # Unscale before clipping so the norm is measured on true gradients.
                    self.scaler.unscale_(self.optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1000)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    _, outputs = self.model(images)
                    loss = self.criterion(outputs, targets)
                    loss.backward()
                    grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1000)
                    self.optimizer.step()
            else:
                if self.amp:
                    with autocast():
                        if choice < self.cutmix:
                            aug_images, aug_targets = cutmix(images, targets, 1.)
                            _, outputs = self.model(aug_images)
                            loss = mix_criterion(outputs, aug_targets, self.criterion)
                        elif choice < self.cutmix + self.fmix:
                            aug_images, aug_targets = fmix(images, targets, alpha=1., decay_power=3.,
                                                           shape=self.config.img_size, device=self.device)
                            aug_images = aug_images.to(self.device).float()
                            _, outputs = self.model(aug_images)
                            loss = mix_criterion(outputs, aug_targets, self.criterion)
                        elif choice < self.cutmix + self.fmix + self.smix:
                            X, ya, yb, lam_a, lam_b = snapmix(images, targets, alpha=0.5, model=self.model)
                            _, outputs, _ = self.model(X)
                            loss = self.snapmix_criterion(self.criterion, outputs, ya, yb, lam_a, lam_b)
                        else:
                            _, outputs = self.model(images)
                            loss = self.criterion(outputs, targets)
                    self.scaler.scale(loss).backward()
                    # Unscale before clipping so the norm is measured on true gradients.
                    self.scaler.unscale_(self.optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1000)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    if choice < self.cutmix:
                        aug_images, aug_targets = cutmix(images, targets, 1.)
                        _, outputs = self.model(aug_images)
                        loss = mix_criterion(outputs, aug_targets, self.criterion)
                    elif choice < self.cutmix + self.fmix:
                        aug_images, aug_targets = fmix(images, targets, alpha=1., decay_power=3.,
                                                       shape=self.config.img_size, device=self.device)
                        aug_images = aug_images.to(self.device).float()
                        _, outputs = self.model(aug_images)
                        loss = mix_criterion(outputs, aug_targets, self.criterion)
                    elif choice < self.cutmix + self.fmix + self.smix:
                        X, ya, yb, lam_a, lam_b = snapmix(images, targets, alpha=0.5, model=self.model)
                        _, outputs, _ = self.model(X)
                        loss = self.snapmix_criterion(self.criterion, outputs, ya, yb, lam_a, lam_b)
                    else:
                        _, outputs = self.model(images)
                        loss = self.criterion(outputs, targets)
                    # Backpropagate the (possibly mixed) loss; do not recompute it with the
                    # plain criterion, which would discard the mix-augmentation targets.
                    loss.backward()
                    grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1000)
                    self.optimizer.step()

            if len(example_images) < 16:
                example_images.append(
                    wandb.Image(
                        images[0],
                        # caption=f"Truth: {targets[0].argmax(1).detach().cpu().item()}"
                    ))
            summary_loss.update(loss.detach().item(), batch_size)
            summary_acc.update(
                (outputs.argmax(1) == targets.argmax(1)).sum().item() / batch_size,
                batch_size)
        return summary_loss, summary_acc, example_images

    def predict(self, test_loader, sub):
        self.model.eval()
        all_outputs = torch.tensor([], device=self.device)
        for step, (images, fnames) in enumerate(test_loader):
            with torch.no_grad():
                images = images.to(self.device).float()
                outputs = self.model.forward(images)
                all_outputs = torch.cat((all_outputs, outputs), 0)
        sub.iloc[:, 1] = all_outputs.detach().cpu().numpy()
        return sub

    def save(self, path):
        self.model.eval()
        torch.save(
            {
                'model_state_dict': self.model.state_dict(),
                # 'optimizer_state_dict': self.optimizer.state_dict(),
                # 'scheduler_state_dict': self.scheduler.state_dict(),
                'best_summary_loss': self.best_summary_loss,
                'epoch': self.epoch,
            }, path)
        wandb.save(path.split("/")[-1])

    def load(self, path):
        checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        # self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_summary_loss = checkpoint['best_summary_loss']
        self.epoch = checkpoint['epoch'] + 1
        self.start_epoch = checkpoint['epoch'] + 1

    def log(self, message):
        if self.config.verbose:
            print(message)
        with open(self.log_path, 'a+') as logger:
            logger.write(f'{message}\n')
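# The cutmix / mix_criterion helpers used in train_one_epoch above are not defined in this
# excerpt. Below is a minimal sketch of a common CutMix implementation that matches the call
# sites (cutmix(images, targets, alpha) -> (aug_images, aug_targets) and
# mix_criterion(outputs, aug_targets, criterion)); the tuple format of aug_targets is an
# assumption, and the project's own versions may differ.
import numpy as np
import torch


def cutmix(images, targets, alpha=1.0):
    # Sample a mixing ratio and a random box, then paste that patch from a shuffled batch.
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(images.size(0), device=images.device)
    _, _, h, w = images.shape
    cut_h, cut_w = int(h * np.sqrt(1 - lam)), int(w * np.sqrt(1 - lam))
    cy, cx = np.random.randint(h), np.random.randint(w)
    y1, y2 = np.clip(cy - cut_h // 2, 0, h), np.clip(cy + cut_h // 2, 0, h)
    x1, x2 = np.clip(cx - cut_w // 2, 0, w), np.clip(cx + cut_w // 2, 0, w)
    images = images.clone()
    images[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]
    # Recompute lambda from the exact pasted area.
    lam = 1 - ((y2 - y1) * (x2 - x1) / (h * w))
    return images, (targets, targets[perm], lam)


def mix_criterion(outputs, aug_targets, criterion):
    # Weighted sum of the criterion against both sets of targets.
    targets_a, targets_b, lam = aug_targets
    return lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)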
def train(X_train, y_train, X_dev, y_dev, X_test, y_test):
    num_labels = NUM_EMO
    vocab_size = VOCAB_SIZE
    print('Vocabulary size: ' + str(vocab_size))

    train_data = EmotionDataLoader(X_train, y_train, PAD_LEN)
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    dev_data = EmotionDataLoader(X_dev, y_dev, PAD_LEN)
    dev_loader = DataLoader(dev_data, batch_size=int(BATCH_SIZE / 3) + 2, shuffle=False)
    test_data = EmotionDataLoader(X_test, y_test, PAD_LEN)
    test_loader = DataLoader(test_data, batch_size=int(BATCH_SIZE / 3) + 2, shuffle=False)

    model = AttentionLSTMClassifier(EMBEDDING_DIM, HIDDEN_DIM, vocab_size, num_labels,
                                    BATCH_SIZE, att_mode=opt.attention, use_glove=USE_GLOVE)
    if USE_GLOVE:
        model.load_embedding(tokenizer.get_embeddings())
    # multi-GPU
    # model = nn.DataParallel(model)
    model.cuda()

    if opt.loss == 'ce':
        loss_criterion = nn.CrossEntropyLoss()
    elif opt.loss == 'focal':
        loss_criterion = FocalLoss(gamma=2, reduce=True)
    else:
        raise Exception('loss option not recognised')

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    es = EarlyStopping(patience=PATIENCE)

    for epoch in range(1, 300):
        print('Epoch: ' + str(epoch) + ' ===================================')
        train_loss = 0
        model.train()
        for i, (data, seq_len, label) in tqdm(enumerate(train_loader),
                                              total=(len(train_data) + BATCH_SIZE - 1) // BATCH_SIZE):
            optimizer.zero_grad()
            # ELMo embeddings are computed on the fly and are not fine-tuned.
            data_text = [tokenizer.decode_ids(x) for x in data]
            with torch.no_grad():
                character_ids = batch_to_ids(data_text).cuda()
                elmo_emb = elmo(character_ids)['elmo_representations']
                elmo_emb = (elmo_emb[0] + elmo_emb[1]) / 2  # avg of two layers
            y_pred = model(data.cuda(), seq_len, elmo_emb)
            loss = loss_criterion(y_pred, label.view(-1).cuda())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIPS)
            optimizer.step()
            train_loss += loss.data.cpu().numpy() * data.shape[0]
            del y_pred, loss

        # Evaluate on the dev set after each epoch.
        test_loss = 0
        model.eval()
        for _, (_data, _seq_len, _label) in enumerate(dev_loader):
            with torch.no_grad():
                data_text = [tokenizer.decode_ids(x) for x in _data]
                character_ids = batch_to_ids(data_text).cuda()
                elmo_emb = elmo(character_ids)['elmo_representations']
                elmo_emb = (elmo_emb[0] + elmo_emb[1]) / 2  # avg of two layers
                y_pred = model(_data.cuda(), _seq_len, elmo_emb)
                loss = loss_criterion(y_pred, _label.view(-1).cuda())
                test_loss += loss.data.cpu().numpy() * _data.shape[0]
                del y_pred, loss
        print("Train Loss: " + str(train_loss / len(train_data)) +
              " Evaluation: " + str(test_loss / len(dev_data)))

        if es.step(test_loss):
            print('Early stopping: dev loss stopped improving.')
            break
        # Save a checkpoint after every epoch that passes the early-stopping check.
        with open(f'lstm_elmo_{opt.dataset}_model.pt', 'wb') as f:
            torch.save(model.state_dict(), f)

    # Predict on the test set with the current model weights.
    pred_list = []
    model.eval()
    for _, (_data, _seq_len, _label) in enumerate(test_loader):
        with torch.no_grad():
            data_text = [tokenizer.decode_ids(x) for x in _data]
            character_ids = batch_to_ids(data_text).cuda()
            elmo_emb = elmo(character_ids)['elmo_representations']
            elmo_emb = (elmo_emb[0] + elmo_emb[1]) / 2  # avg of two layers
            y_pred = model(_data.cuda(), _seq_len, elmo_emb)
            pred_list.append(y_pred.data.cpu().numpy())  # x[np.where( x > 3.0 )]
            del y_pred
    pred_list = np.argmax(np.concatenate(pred_list, axis=0), axis=1)
    return pred_list
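# EarlyStopping is used above via es.step(metric), which returns True once the monitored
# value has stopped improving, but its definition lies outside this excerpt. A minimal
# sketch consistent with that lower-is-better usage is given below; the patience/min_delta
# semantics are assumptions, and other snippets in this file may use a variant with a
# different return convention (e.g. stepping on mIoU, where higher is better).
import math


class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.num_bad_epochs = 0

    def is_better(self, metric, best):
        # Improvement means the metric dropped by more than min_delta.
        return metric < best - self.min_delta

    def step(self, metric):
        # Returns True when the metric has not improved for more than `patience` checks.
        if self.is_better(metric, self.best):
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs > self.patience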