Example #1
def train(params, m, datas):
    es = EarlyStopping(min_delta=params.min_delta, patience=params.patience)

    # optimizer
    ps = [p[1] for p in m.named_parameters() if 'discriminator' not in p[0]]
    print('Model parameters: {}'.format(sum(p.numel() for p in ps)))
    optimizer = optim.Adam(ps, lr=params.init_lr)
    if params.adv_training:
        dis_ps = [
            p[1] for p in m.named_parameters() if 'discriminator' in p[0]
        ]
        dis_optimizer = optim.Adam(dis_ps, lr=params.init_lr)
        dis_enc_ps = [
            p[1] for p in m.named_parameters()
            if 'encoder' in p[0] or 'embedding' in p[0]
        ]
        dis_enc_optimizer = optim.Adam(dis_enc_ps, lr=params.init_lr)

    # all training instances, split between the two languages; for now the data are balanced
    n_batch = (len(datas) * datas[0].train_size + params.bs - 1) // params.bs  # ceil division
    data_idxs = {}
    for i, data in enumerate(datas):
        lang = data.vocab.lang
        data_idxs[lang] = list(range(data.train_size))

    # number of iterations
    cur_it = 0
    # write to tensorboard
    writer = SummaryWriter('./history/{}'.format(
        params.log_path)) if params.write_tfboard else None

    nll_dev = math.inf
    best_nll_dev = math.inf
    kld_dev = math.inf

    for i in range(params.ep):
        for lang in data_idxs:
            shuffle(data_idxs[lang])
        for j in range(n_batch):
            if params.task == 'xl' or params.task == 'xl-adv':
                lang_idx = j % len(datas)
                data = datas[lang_idx]
                lang = data.vocab.lang
                train_idxs = data_idxs[lang][j // len(datas) *
                                             params.bs:(j // len(datas) + 1) *
                                             params.bs]
            elif params.task == 'mo':
                lang = params.langs[0]
                lang_idx = params.lang_dict[lang]
                data = datas[lang_idx]
                train_idxs = data_idxs[lang][j * params.bs:(j + 1) * params.bs]
            padded_batch, batch_lens = get_batch(train_idxs, data,
                                                 data.train_idxs,
                                                 data.train_lens, params.cuda)

            optimizer.zero_grad()
            if params.adv_training:
                dis_optimizer.zero_grad()
                dis_enc_optimizer.zero_grad()
            m.train()

            nll_batch, kld_batch, ls_dis, ls_enc = m(lang, padded_batch,
                                                     batch_lens)

            cur_it += 1
            loss_batch, alpha = calc_loss_batch(params, nll_batch, kld_batch,
                                                cur_it, n_batch)
            # (disabled) add adversarial loss to the encoder:
            # if cur_it > params.adv_ep * n_batch:
            #     loss_batch += ls_enc

            if not params.adv_training:
                loss_batch.backward()
                optimizer.step()
            else:
                ls_dis = ls_dis.mean()
                ls_enc = ls_enc.mean()
                loss_batch = loss_batch + ls_dis + ls_enc
                loss_batch.backward()
                optimizer.step()
                dis_optimizer.step()
                dis_enc_optimizer.step()

            out_xling(i,
                      j,
                      n_batch,
                      loss_batch,
                      nll_batch,
                      kld_batch,
                      best_nll_dev,
                      nll_dev,
                      kld_dev,
                      es.num_bad_epochs,
                      ls_dis=ls_dis,
                      ls_enc=ls_enc)
            update_tensorboard(writer,
                               loss_batch,
                               nll_batch,
                               kld_batch,
                               alpha,
                               nll_dev,
                               kld_dev,
                               cur_it,
                               ls_dis=ls_dis,
                               ls_enc=ls_enc)

            if cur_it % params.VAL_EVERY == 0:
                sys.stdout.write('\n')
                sys.stdout.flush()
                # validation
                nll_dev, kld_dev = test(params, m, datas)
                if es.step(nll_dev):
                    print('\nEarly Stopped.')
                    return
                elif es.is_better(nll_dev, best_nll_dev):
                    best_nll_dev = nll_dev
                    # save model
                    for lang in params.langs:
                        lang_idx = params.lang_dict[lang]
                        m.save_embedding(params, datas[lang_idx])
                    m.save_model(params, datas)
            test_loss_sum = 0  # reset before each pass over the test loader
            for _, (src_test,
                    trg_test) in tqdm(enumerate(test_loader),
                                      total=int(len(test_set) / batch_size)):
                test_logit = model(
                    Variable(src_test).to(device),
                    Variable(trg_test).to(device))
                trg_test = torch.cat((torch.index_select(
                    trg_test, 1, torch.LongTensor(list(range(1, pad_len)))),
                                      torch.LongTensor(
                                          np.zeros([trg_test.shape[0], 1]))),
                                     dim=1)
                test_loss = loss_criterion(
                    test_logit.contiguous().view(-1, vocab_size),
                    Variable(trg_test).view(-1).to(device))
                test_loss_sum += test_loss.item()
                del test_loss, test_logit

        print("Evaluation Loss", test_loss_sum)
        # es.new_loss(test_loss_sum)
        if es.step(test_loss_sum):
            print('Started overfitting')
            break
        # Save Model
        torch.save(
            model.state_dict(),
            open(
                os.path.join(
                    'checkpoint',
                    'new_simple_bar' + '_epoch_%d' % i + '.model'),
                'wb'))
    def one_fold(num_fold, train_index, dev_index):
        print("Training on fold:", num_fold)
        X_train, X_dev = [X[i] for i in train_index], [X[i] for i in dev_index]
        y_train, y_dev = y[train_index], y[dev_index]

        # construct data loader
        # for one fold, test data comes from k fold split.
        train_data_set = create_data.TrainDataSet(X_train,
                                                  y_train,
                                                  EMAI_PAD_LEN,
                                                  SENT_PAD_LEN,
                                                  word2id,
                                                  emoji_st,
                                                  use_unk=True)

        dev_data_set = create_data.TrainDataSet(X_dev,
                                                y_dev,
                                                EMAI_PAD_LEN,
                                                SENT_PAD_LEN,
                                                word2id,
                                                emoji_st,
                                                use_unk=True)
        dev_data_loader = DataLoader(dev_data_set,
                                     batch_size=BATCH_SIZE,
                                     shuffle=False)
        # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        final_pred_best = None

        # Guard against model divergence; if it happens, retrain from scratch
        while True:

            is_diverged = False
            # Model is defined in HierarchicalPredictor

            if CONTINUE:
                model = torch.load(opt.out_path)
            else:
                model = HierarchicalAttPredictor(SENT_EMB_DIM,
                                                 SENT_HIDDEN_SIZE,
                                                 CTX_LSTM_DIM,
                                                 num_of_vocab,
                                                 SENT_PAD_LEN,
                                                 id2word,
                                                 USE_ELMO=True,
                                                 ADD_LINEAR=False)
                model.load_embedding(emb)
                model.deepmoji_model.load_specific_weights(
                    PRETRAINED_PATH, exclude_names=['output_layer'])

            model.cuda()
            optimizer = optim.Adam(model.parameters(),
                                   lr=learning_rate,
                                   amsgrad=True)
            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                               gamma=GAMMA)

            # loss_criterion_binary = nn.CrossEntropyLoss(weight=weight_list_binary)  #
            if loss == 'focal':
                loss_criterion = FocalLoss(gamma=opt.focal)

            elif loss == 'ce':
                loss_criterion = nn.BCELoss()

            es = EarlyStopping(patience=EARLY_STOP_PATIENCE)
            final_pred_list_test = None

            result_print = {}

            for num_epoch in range(MAX_EPOCH):

                # re-create the loader so the data are shuffled at every epoch
                train_data_loader = DataLoader(train_data_set,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True)

                print('Begin training epoch:', num_epoch, end='...\t')
                sys.stdout.flush()

                # stepping scheduler
                scheduler.step(num_epoch)
                print('Current learning rate', scheduler.get_lr())

                ## Training step
                train_loss = 0
                model.train()

                for i, (a, a_len, emoji_a, e_c) \
                        in tqdm(enumerate(train_data_loader), total=len(train_data_set)/BATCH_SIZE):

                    optimizer.zero_grad()
                    e_c = e_c.type(torch.float)
                    pred = model(a.cuda(), a_len, emoji_a.cuda())
                    loss_label = loss_criterion(pred.squeeze(1),
                                                e_c.view(-1).cuda()).cuda()

                    # training trilogy
                    loss_label.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
                    optimizer.step()

                    train_loss += loss_label.data.cpu().numpy() * a.shape[0]
                    del pred, loss_label

                ## Evaluation step
                model.eval()
                dev_loss = 0
                # pred_list = []
                for i, (a, a_len, emoji_a, e_c) in enumerate(dev_data_loader):

                    with torch.no_grad():
                        e_c = e_c.type(torch.float)
                        pred = model(a.cuda(), a_len, emoji_a.cuda())

                        loss_label = loss_criterion(
                            pred.squeeze(1),
                            e_c.view(-1).cuda()).cuda()

                        dev_loss += loss_label.data.cpu().numpy() * a.shape[0]

                        # pred_list.append(pred.data.cpu().numpy())
                        # gold_list.append(e_c.numpy())
                        del pred, loss_label

                print('Training loss:',
                      train_loss / len(train_data_set),
                      end='\t')
                print('Dev loss:', dev_loss / len(dev_data_set))

                # print(classification_report(gold_list, pred_list, target_names=EMOS))
                # get_metrics(pred_list, gold_list)

                # Gold Test testing
                print('Final test testing...')
                final_pred_list_test = []
                model.eval()

                for i, (a, a_len,
                        emoji_a) in enumerate(final_test_data_loader):

                    with torch.no_grad():

                        pred = model(a.cuda(), a_len, emoji_a.cuda())

                        final_pred_list_test.append(pred.data.cpu().numpy())
                    del a, pred
                print("final_pred_list_test", len(final_pred_list_test))
                final_pred_list_test = np.concatenate(final_pred_list_test,
                                                      axis=0)
                final_pred_list_test = np.squeeze(final_pred_list_test, axis=1)
                print("final_pred_list_test_concat", len(final_pred_list_test))

                accuracy, precision, recall, f1 = get_metrics(
                    np.asarray(final_test_target_list),
                    np.asarray(final_pred_list_test))

                result_print.update(
                    {num_epoch: [accuracy, precision, recall, f1]})

                if dev_loss / len(dev_data_set) > 1.3 and num_epoch > 4:
                    print("Model diverged, retry")
                    is_diverged = True
                    break

                if es.step(dev_loss):  # overfitting
                    print('overfitting, loading best model ...')
                    break
                else:
                    if es.is_best():
                        print('saving best model ...')
                        if final_pred_best is not None:
                            del final_pred_best
                        final_pred_best = deepcopy(final_pred_list_test)

                    else:
                        print('not best model, ignoring ...')
                        if final_pred_best is None:
                            final_pred_best = deepcopy(final_pred_list_test)

            with open(result_path, 'wb') as w:
                pkl.dump(result_print, w)

            if is_diverged:
                print("Reinitialize model ...")
                del model

                continue

            real_test_results.append(np.asarray(final_pred_best))
            # saving model for inference
            torch.save(model, opt.out_path)
            del model
            break
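
Note: every snippet in this section leans on an EarlyStopping helper whose implementation is not shown. The following is a minimal sketch, assuming only the call sites visible in these examples (step / is_better / is_best, plus the num_bad_epochs, best, mode, min_delta and patience members); the real helper in these codebases may differ, e.g. in whether patience is compared strictly, and Example #9's variant additionally takes cumulative_delta and a logger.

class EarlyStopping:
    def __init__(self, mode='min', min_delta=0.0, patience=10):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self._is_best = False

    def is_better(self, metric, best):
        # 'min' mode tracks losses, 'max' mode tracks accuracies/scores.
        if self.mode == 'min':
            return metric < best - self.min_delta
        return metric > best + self.min_delta

    def is_best(self):
        # True iff the most recent step() improved the best metric so far.
        return self._is_best

    def step(self, metric):
        # Returns True once patience is exhausted (signal to stop training).
        if self.best is None or self.is_better(metric, self.best):
            self.best = metric
            self.num_bad_epochs = 0
            self._is_best = True
        else:
            self.num_bad_epochs += 1
            self._is_best = False
        return self.num_bad_epochs > self.patience
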
Example #4
def train(params, m, datas):
    # early stopping
    es = EarlyStopping(mode='max', patience=params.cldc_patience)
    # set optimizer
    optimizer = get_optimizer(params, m)

    # train on one language; dev and test are in another language
    # get training
    train_lang, train_data = get_lang_data(params, datas, training=True)
    # get dev and test, dev is the same language as test
    test_lang, test_data = get_lang_data(params, datas)

    n_batch = (train_data.train_size + params.cldc_bs - 1) // params.cldc_bs  # ceil division
    # per category
    data_idxs = [
        list(range(len(train_idx))) for train_idx in train_data.train_idxs
    ]

    # number of iterations
    cur_it = 0
    # write to tensorboard
    writer = SummaryWriter('./history/{}'.format(
        params.log_path)) if params.write_tfboard else None
    # best dev/test scores so far
    bdev = 0
    btest = 0
    # current dev/test scores
    cdev = 0
    ctest = 0
    dev_class_acc = {}
    test_class_acc = {}
    dev_cm = None
    test_cm = None
    # early-stopping warm-up flag: es only starts counting once the loss falls below the threshold
    es_flag = False

    for i in range(params.cldc_ep):
        for data_idx in data_idxs:
            shuffle(data_idx)
        for j in range(n_batch):
            train_idxs = []
            for k, data_idx in enumerate(data_idxs):
                if j < n_batch - 1:
                    train_idxs.append(
                        data_idx[int(j * params.cldc_bs *
                                     train_data.train_prop[k]):int(
                                         (j + 1) * params.cldc_bs *
                                         train_data.train_prop[k])])
                elif j == n_batch - 1:
                    train_idxs.append(data_idx[int(j * params.cldc_bs *
                                                   train_data.train_prop[k]):])

            batch_train, batch_train_lens, batch_train_lb = get_batch(
                params, train_idxs, train_data.train_idxs,
                train_data.train_lens)
            optimizer.zero_grad()
            m.train()

            cldc_loss_batch, _, batch_pred = m(train_lang, batch_train,
                                               batch_train_lens,
                                               batch_train_lb)

            batch_acc, batch_acc_cls = get_classification_report(
                params,
                batch_train_lb.data.cpu().numpy(),
                batch_pred.data.cpu().numpy())

            if cldc_loss_batch < params.cldc_lossth:
                es_flag = True

            cldc_loss_batch.backward()
            out_cldc(i, j, n_batch, cldc_loss_batch, batch_acc, batch_acc_cls,
                     bdev, btest, cdev, ctest, es.num_bad_epochs)

            optimizer.step()
            cur_it += 1
            update_tensorboard(writer, cldc_loss_batch, batch_acc, cdev, ctest,
                               dev_class_acc, test_class_acc, cur_it)

            if cur_it % params.CLDC_VAL_EVERY == 0:
                sys.stdout.write('\n')
                sys.stdout.flush()
                # validation
                #cdev, dev_class_acc, dev_cm = test(params, m, test_data.dev_idxs, test_data.dev_lens, test_data.dev_size, test_data.dev_prop, test_lang, cm = True)
                cdev, dev_class_acc, dev_cm = test(params,
                                                   m,
                                                   train_data.dev_idxs,
                                                   train_data.dev_lens,
                                                   train_data.dev_size,
                                                   train_data.dev_prop,
                                                   train_lang,
                                                   cm=True)
                ctest, test_class_acc, test_cm = test(params,
                                                      m,
                                                      test_data.test_idxs,
                                                      test_data.test_lens,
                                                      test_data.test_size,
                                                      test_data.test_prop,
                                                      test_lang,
                                                      cm=True)
                print(dev_cm)
                print(test_cm)
                if es.step(cdev):
                    print('\nEarly Stopped.')
                    return
                elif es.is_better(cdev, bdev):
                    bdev = cdev
                    btest = ctest
                    #save_model(params, m)
                # reset bad epochs
                if not es_flag:
                    es.num_bad_epochs = 0
def main(params, m, data):
    # early stopping
    es = EarlyStopping(mode='max', patience=params.patience)
    # set optimizer
    optimizer = get_optimizer(params, m)

    n_batch = (data.train_size + params.bs - 1) // params.bs  # ceil division
    # per category
    data_idxs = [list(range(len(train_idx))) for train_idx in data.train_idxs]

    # number of iterations
    cur_it = 0
    # best dev/test scores so far
    bdev = 0
    btest = 0
    # current dev/test scores
    cdev = 0
    ctest = 0
    dev_class_acc = {}
    test_class_acc = {}
    dev_cm = None
    test_cm = None
    # early-stopping warm-up flag: es only starts counting once the loss falls below the threshold
    es_flag = False

    for i in range(params.ep):
        # self-training
        if params.self_train or i >= params.semi_warm_up:
            params.self_train = True
            first_update = (i == params.semi_warm_up)
            # only for zero-shot
            if first_update:
                es.num_bad_epochs = 0
                es.best = 0
                bdev = 0
                btest = 0
            data = self_train_merge_data(params,
                                         m,
                                         es,
                                         data,
                                         first=first_update)
            n_batch = (data.self_train_size + params.bs - 1) // params.bs  # ceil division
            # per category
            data_idxs = [
                list(range(len(train_idx)))
                for train_idx in data.self_train_idxs
            ]

        for data_idx in data_idxs:
            shuffle(data_idx)
        for j in range(n_batch):
            train_idxs = []
            for k, data_idx in enumerate(data_idxs):
                if params.self_train:
                    train_prop = data.self_train_prop
                else:
                    train_prop = data.train_prop
                if j < n_batch - 1:
                    train_idxs.append(
                        data_idx[int(j * params.bs *
                                     train_prop[k]):int((j + 1) * params.bs *
                                                        train_prop[k])])
                elif j == n_batch - 1:
                    train_idxs.append(data_idx[int(j * params.bs *
                                                   train_prop[k]):])

            if params.self_train:
                batch_train, _, batch_train_lb = get_batch(
                    params, train_idxs, data.self_train_idxs,
                    data.self_train_lens)
            else:
                batch_train, _, batch_train_lb = get_batch(
                    params, train_idxs, data.train_idxs, data.train_lens)
            optimizer.zero_grad()
            m.train()

            loss_batch, logits = m(batch_train, labels=batch_train_lb)
            batch_pred = torch.argmax(logits, dim=1)

            batch_acc, batch_acc_cls = get_classification_report(
                params,
                batch_train_lb.data.cpu().numpy(),
                batch_pred.data.cpu().numpy())

            if loss_batch < params.lossth:
                es_flag = True

            loss_batch.backward()
            out_cldc(i, j, n_batch, loss_batch, batch_acc, batch_acc_cls, bdev,
                     btest, cdev, ctest, es.num_bad_epochs)

            optimizer.step()
            cur_it += 1

        sys.stdout.write('\n')
        sys.stdout.flush()
        # validation
        cdev, dev_class_acc, dev_cm = test(params,
                                           m,
                                           data.dev_idxs,
                                           data.dev_lens,
                                           data.dev_size,
                                           data.dev_prop,
                                           cm=True)
        ctest, test_class_acc, test_cm = test(params,
                                              m,
                                              data.test_idxs,
                                              data.test_lens,
                                              data.test_size,
                                              data.test_prop,
                                              cm=True)
        print(dev_cm)
        print(test_cm)
        if es.step(cdev):
            print('\nEarly Stopped.')
            return
        elif es.is_better(cdev, bdev):
            bdev = cdev
            btest = ctest
        # reset bad epochs
        if not es_flag:
            es.num_bad_epochs = 0
    def one_fold(num_fold, train_index, dev_index):
        print("Training on fold:", num_fold)
        X_train, X_dev = [X[i] for i in train_index], [X[i] for i in dev_index]
        y_train, y_dev = y[train_index], y[dev_index]

        # construct data loader
        train_data_set = TrainDataSet(X_train,
                                      y_train,
                                      CONV_PAD_LEN,
                                      SENT_PAD_LEN,
                                      word2id,
                                      use_unk=True)

        dev_data_set = TrainDataSet(X_dev,
                                    y_dev,
                                    CONV_PAD_LEN,
                                    SENT_PAD_LEN,
                                    word2id,
                                    use_unk=True)
        dev_data_loader = DataLoader(dev_data_set,
                                     batch_size=BATCH_SIZE,
                                     shuffle=False)
        # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        pred_list_test_best = None
        final_pred_best = None
        # Guard against model divergence; if it happens, retrain from scratch
        while True:
            is_diverged = False
            model = HierarchicalPredictor(SENT_EMB_DIM,
                                          SENT_HIDDEN_SIZE,
                                          num_of_vocab,
                                          USE_ELMO=True,
                                          ADD_LINEAR=False)
            model.load_embedding(emb)
            model.cuda()
            # model = nn.DataParallel(model)
            # model.to(device)
            optimizer = optim.Adam(model.parameters(),
                                   lr=learning_rate,
                                   amsgrad=True)  #
            # optimizer = optim.SGD(model.parameters(), lr=learning_rate)
            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                               gamma=opt.gamma)

            if opt.w == 1:
                weight_list = [0.3, 0.3, 0.3, 1.7]
                weight_list_binary = [0.3, 1.7]
            elif opt.w == 2:
                weight_list = [
                    0.3198680179, 0.246494733, 0.2484349259, 1.74527696
                ]
                weight_list_binary = [2 - weight_list[-1], weight_list[-1]]
            weight_list = [x**FLAT for x in weight_list]
            weight_label = torch.Tensor(weight_list).cuda()

            weight_list_binary = [x**FLAT for x in weight_list_binary]
            weight_binary = torch.Tensor(weight_list_binary).cuda()
            print('classification reweight: ', weight_list)
            print('binary loss reweight = weight_list_binary',
                  weight_list_binary)
            # loss_criterion_binary = nn.CrossEntropyLoss(weight=weight_list_binary)  #
            if opt.loss == 'focal':
                loss_criterion = FocalLoss(gamma=opt.focal, reduce=False)
                loss_criterion_binary = FocalLoss(gamma=opt.focal,
                                                  reduce=False)  #
            elif opt.loss == 'ce':
                loss_criterion = nn.CrossEntropyLoss(reduce=False)
                loss_criterion_binary = nn.CrossEntropyLoss(reduce=False)  #

            loss_criterion_emo_only = nn.MSELoss()

            # es = EarlyStopping(min_delta=0.005, patience=EARLY_STOP_PATIENCE)
            es = EarlyStopping(patience=EARLY_STOP_PATIENCE)
            # best_model = None
            final_pred_list_test = None
            pred_list_test = None
            for num_epoch in range(MAX_EPOCH):
                # re-create the loader so the data are shuffled at every epoch
                train_data_loader = DataLoader(train_data_set,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True)

                print('Begin training epoch:', num_epoch, end='...\t')
                sys.stdout.flush()

                # stepping scheduler
                scheduler.step(num_epoch)
                print('Current learning rate', scheduler.get_lr())

                train_loss = 0
                model.train()
                for i, (a, a_len, emoji_a, e_c, e_c_binary, e_c_emo) \
                        in tqdm(enumerate(train_data_loader), total=len(train_data_set)/BATCH_SIZE):
                    optimizer.zero_grad()
                    elmo_a = elmo_encode(a)

                    pred, pred2, pred3 = model(a.cuda(), a_len, emoji_a.cuda(),
                                               elmo_a)

                    loss_label = loss_criterion(pred,
                                                e_c.view(-1).cuda()).cuda()
                    loss_label = torch.matmul(torch.gather(weight_label, 0, e_c.view(-1).cuda()), loss_label) / \
                                 e_c.view(-1).shape[0]

                    loss_binary = loss_criterion_binary(
                        pred2,
                        e_c_binary.view(-1).cuda()).cuda()
                    loss_binary = torch.matmul(
                        torch.gather(weight_binary, 0,
                                     e_c_binary.view(-1).cuda()),
                        loss_binary) / e_c.view(-1).shape[0]

                    loss_emo = loss_criterion_emo_only(pred3, e_c_emo.cuda())

                    loss = (loss_label + LAMBDA1 * loss_binary +
                            LAMBDA2 * loss_emo) / float(1 + LAMBDA1 + LAMBDA2)

                    # loss = torch.matmul(torch.gather(weight, 0, trg.view(-1).cuda()), loss) / trg.view(-1).shape[0]

                    # training trilogy
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
                    optimizer.step()

                    train_loss += loss.data.cpu().numpy() * a.shape[0]
                    del pred, loss, elmo_a, e_c_emo, loss_binary, loss_label, loss_emo

                # Evaluate
                model.eval()
                dev_loss = 0
                # pred_list = []
                # gold_list = []
                for i, (a, a_len, emoji_a, e_c, e_c_binary, e_c_emo) \
                        in enumerate(dev_data_loader):
                    with torch.no_grad():
                        elmo_a = elmo_encode(a)

                        pred, pred2, pred3 = model(a.cuda(), a_len,
                                                   emoji_a.cuda(), elmo_a)

                        loss_label = loss_criterion(
                            pred,
                            e_c.view(-1).cuda()).cuda()
                        loss_label = torch.matmul(
                            torch.gather(weight_label, 0,
                                         e_c.view(-1).cuda()),
                            loss_label) / e_c.view(-1).shape[0]

                        loss_binary = loss_criterion_binary(
                            pred2,
                            e_c_binary.view(-1).cuda()).cuda()
                        loss_binary = torch.matmul(
                            torch.gather(weight_binary, 0,
                                         e_c_binary.view(-1).cuda()),
                            loss_binary) / e_c.view(-1).shape[0]

                        loss_emo = loss_criterion_emo_only(
                            pred3, e_c_emo.cuda())

                        loss = (loss_label + LAMBDA1 * loss_binary + LAMBDA2 *
                                loss_emo) / float(1 + LAMBDA1 + LAMBDA2)

                        dev_loss += loss.data.cpu().numpy() * a.shape[0]

                        # pred_list.append(pred.data.cpu().numpy())
                        # gold_list.append(e_c.numpy())
                        del pred, loss, elmo_a, e_c_emo, loss_binary, loss_label, loss_emo

                print('Training loss:',
                      train_loss / len(train_data_set),
                      end='\t')
                print('Dev loss:', dev_loss / len(dev_data_set))
                # print(classification_report(gold_list, pred_list, target_names=EMOS))
                # get_metrics(pred_list, gold_list)
                if dev_loss / len(dev_data_set) > 1.3 and num_epoch > 4:
                    print("Model diverged, retry")
                    is_diverged = True
                    break

                if es.step(dev_loss):  # overfitting
                    print('overfitting, loading best model ...')
                    break
                else:
                    if es.is_best():
                        print('saving best model ...')
                        if final_pred_best is not None:
                            del final_pred_best
                        final_pred_best = deepcopy(final_pred_list_test)
                        if pred_list_test_best is not None:
                            del pred_list_test_best
                        pred_list_test_best = deepcopy(pred_list_test)
                    else:
                        print('not best model, ignoring ...')
                        if final_pred_best is None:
                            final_pred_best = deepcopy(final_pred_list_test)
                        if pred_list_test_best is None:
                            pred_list_test_best = deepcopy(pred_list_test)

                # Gold Dev testing...
                print('Gold Dev testing....')
                pred_list_test = []
                model.eval()
                for i, (a, a_len, emoji_a) in enumerate(gold_dev_data_loader):
                    with torch.no_grad():
                        elmo_a = elmo_encode(a)  # , __id2word=ex_id2word

                        pred, _, _ = model(a.cuda(), a_len, emoji_a.cuda(),
                                           elmo_a)

                        pred_list_test.append(pred.data.cpu().numpy())
                    del elmo_a, a, pred
                pred_list_test = np.argmax(np.concatenate(pred_list_test,
                                                          axis=0),
                                           axis=1)
                # get_metrics(load_dev_labels('data/dev.txt'), pred_list_test)

                # Testing
                print('Gold test testing...')
                final_pred_list_test = []
                model.eval()
                for i, (a, a_len, emoji_a) in enumerate(test_data_loader):
                    with torch.no_grad():
                        elmo_a = elmo_encode(a)  # , __id2word=ex_id2word

                        pred, _, _ = model(a.cuda(), a_len, emoji_a.cuda(),
                                           elmo_a)

                        final_pred_list_test.append(pred.data.cpu().numpy())
                    del elmo_a, a, pred
                final_pred_list_test = np.argmax(np.concatenate(
                    final_pred_list_test, axis=0),
                                                 axis=1)
                # get_metrics(load_dev_labels('data/test.txt'), final_pred_list_test)

            if is_diverged:
                print("Reinitialize model ...")
                del model
                continue

            all_fold_results.append(pred_list_test_best)
            real_test_results.append(final_pred_best)
            del model
            break
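
A pattern worth calling out from one_fold above (it recurs in the next example): the criteria are built with reduce=False, which is the deprecated spelling of reduction='none', so they return one loss per example; the per-class weights are then applied by indexing the weight vector with the gold labels via torch.gather and reducing with a dot product. A standalone sketch with illustrative tensors:

import torch
import torch.nn as nn

logits = torch.randn(8, 4)                           # batch of 8, 4 classes
targets = torch.randint(0, 4, (8,))                  # gold labels
class_weights = torch.tensor([0.3, 0.3, 0.3, 1.7])   # e.g. the opt.w == 1 list

per_example = nn.CrossEntropyLoss(reduction='none')(logits, targets)
per_example_w = torch.gather(class_weights, 0, targets)  # weight of each gold class
loss = torch.matmul(per_example_w, per_example) / targets.shape[0]

Note this is not identical to passing weight= to nn.CrossEntropyLoss with mean reduction, which normalizes by the sum of the selected weights rather than by the batch size.
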
Example #7
    def one_fold(num_fold, train_index, dev_index):
        print("Training on fold:", num_fold)
        X_train, X_dev = [X[i] for i in train_index], [X[i] for i in dev_index]
        y_train, y_dev = y[train_index], y[dev_index]

        # construct data loader
        train_data_set = DataSet(X_train, y_train, SENT_PAD_LEN)
        train_data_loader = DataLoader(train_data_set, batch_size=BATCH_SIZE, shuffle=True)

        dev_data_set = DataSet(X_dev, y_dev, SENT_PAD_LEN)
        dev_data_loader = DataLoader(dev_data_set, batch_size=BATCH_SIZE, shuffle=False)
        gradient_accumulation_steps = 1
        num_train_steps = int(
            len(train_data_set) / BATCH_SIZE / gradient_accumulation_steps * MAX_EPOCH)

        pred_list_test_best = None
        final_pred_best = None
        # Guard against model divergence; if it happens, retrain from scratch
        while True:
            is_diverged = False
            model = BERT_classifer.from_pretrained(BERT_MODEL)
            model.add_output_layer(BERT_MODEL, NUM_EMO)
            model = nn.DataParallel(model)
            if HALF_PRECISION:
                # model = network_to_half(model)
                model.half()
            model.to(device)
            #model.cpu()

            # BERT optimizer
            param_optimizer = list(model.named_parameters())
            no_decay = ['bias', 'gamma', 'beta']
            optimizer_grouped_parameters = [
                {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                 'weight_decay_rate': 0.01},
                {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
            ]

            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=learning_rate,
                                 warmup=0.1,
                                 t_total=num_train_steps)

            if opt.w == 1:
                weight_list = [0.3, 0.3, 0.3, 1.7]
                weight_list_binary = [2 - weight_list[-1], weight_list[-1]]
            elif opt.w == 2:
                weight_list = [0.3198680179, 0.246494733, 0.2484349259, 1.74527696]
                weight_list_binary = [2 - weight_list[-1], weight_list[-1]]

            weight_list = [x**FLAT for x in weight_list]
            weight_label = torch.Tensor(weight_list).to(device)

            weight_list_binary = [x**FLAT for x in weight_list_binary]
            weight_binary = torch.Tensor(weight_list_binary).to(device)
            print('binary loss reweight = weight_list_binary', weight_list_binary)
            # loss_criterion_binary = nn.CrossEntropyLoss(weight=weight_list_binary)  #
            if opt.loss == 'focal':
                loss_criterion = FocalLoss(gamma=opt.focal, reduce=False)
                loss_criterion_binary = FocalLoss(gamma=opt.focal, reduce=False)  #
            elif opt.loss == 'ce':
                loss_criterion = nn.CrossEntropyLoss(reduce=False)
                loss_criterion_binary = nn.CrossEntropyLoss(reduce=False)  #

            loss_criterion_emo_only = nn.MSELoss()

            # es = EarlyStopping(min_delta=0.005, patience=EARLY_STOP_PATIENCE)
            es = EarlyStopping(patience=EARLY_STOP_PATIENCE)
            final_pred_best = None
            final_pred_list_test = None
            pred_list_test = None
            for num_epoch in range(MAX_EPOCH):
                print('Begin training epoch:', num_epoch)
                sys.stdout.flush()
                train_loss = 0
                model.train()
                for i, (tokens, masks, segments, e_c, e_c_binary, e_c_emo) in tqdm(enumerate(train_data_loader),
                                                              total=len(train_data_set)/BATCH_SIZE):
                    optimizer.zero_grad()

                    if USE_TOKEN_TYPE:
                        pred, pred2, pred3 = model(tokens.to(device), masks.to(device), segments.to(device))
                    else:
                        pred, pred2, pred3 = model(tokens.to(device), masks.to(device))

                    loss_label = loss_criterion(pred, e_c.view(-1).to(device)).to(device)
                    loss_label = torch.matmul(torch.gather(weight_label, 0, e_c.view(-1).to(device)), loss_label) / \
                                 e_c.view(-1).shape[0]

                    loss_binary = loss_criterion_binary(pred2, e_c_binary.view(-1).to(device)).to(device)
                    loss_binary = torch.matmul(torch.gather(weight_binary, 0, e_c_binary.view(-1).to(device)),
                                               loss_binary) / e_c.view(-1).shape[0]

                    loss_emo = loss_criterion_emo_only(pred3, e_c_emo.to(device))

                    loss = (loss_label + LAMBDA1 * loss_binary + LAMBDA2 * loss_emo) / float(1 + LAMBDA1 + LAMBDA2)

                    # training trilogy
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
                    optimizer.step()

                    train_loss += loss.data.cpu().numpy() * tokens.shape[0]

                    del loss, pred

                # Evaluate
                model.eval()
                dev_loss = 0
                # pred_list = []
                # gold_list = []
                for i, (tokens, masks, segments, e_c, e_c_binary, e_c_emo) in enumerate(dev_data_loader):
                    with torch.no_grad():
                        if USE_TOKEN_TYPE:
                            pred, pred2, pred3 = model(tokens.to(device), masks.to(device), segments.to(device))
                        else:
                            pred, pred2, pred3 = model(tokens.to(device), masks.to(device))

                        loss_label = loss_criterion(pred, e_c.view(-1).to(device)).to(device)
                        loss_label = torch.matmul(torch.gather(weight_label, 0, e_c.view(-1).to(device)), loss_label) / \
                                     e_c.view(-1).shape[0]

                        loss_binary = loss_criterion_binary(pred2, e_c_binary.view(-1).to(device)).to(device)
                        loss_binary = torch.matmul(torch.gather(weight_binary, 0, e_c_binary.view(-1).to(device)),
                                                   loss_binary) / e_c.view(-1).shape[0]

                        loss_emo = loss_criterion_emo_only(pred3, e_c_emo.to(device))

                        loss = (loss_label + LAMBDA1 * loss_binary + LAMBDA2 * loss_emo) / float(1 + LAMBDA1 + LAMBDA2)

                        dev_loss += loss.data.cpu().numpy() * tokens.shape[0]

                        # pred_list.append(pred.data.cpu().numpy())
                        # gold_list.append(e_c.numpy())
                        del pred, loss

                # pred_list = np.argmax(np.concatenate(pred_list, axis=0), axis=1)
                # gold_list = np.concatenate(gold_list, axis=0)
                print('Training loss:', train_loss / len(train_data_set), end='\t')
                print('Dev loss:', dev_loss / len(dev_data_set))
                # print(classification_report(gold_list, pred_list, target_names=EMOS))
                # get_metrics(pred_list, gold_list)
                # checking diverge
                if dev_loss/len(dev_data_set) > 1.3 and num_epoch > 4:
                    print("Model diverged, retry")
                    is_diverged = True
                    break

                if es.step(dev_loss):  # overfitting
                    print('overfitting, loading best model ...')
                    if num_epoch == 1:
                        is_diverged = True
                        final_pred_best = deepcopy(final_pred_list_test)
                        pred_list_test_best = deepcopy(pred_list_test)
                    break
                else:
                    if es.is_best():
                        print('saving best model ...')
                        if final_pred_best is not None:
                            del final_pred_best
                        final_pred_best = deepcopy(final_pred_list_test)
                        if pred_list_test_best is not None:
                            del pred_list_test_best
                        pred_list_test_best = deepcopy(pred_list_test)
                    else:
                        print('not best model, ignoring ...')
                        if final_pred_best is None:
                            final_pred_best = deepcopy(final_pred_list_test)
                        if pred_list_test_best is None:
                            pred_list_test_best = deepcopy(pred_list_test)

                print('Gold Dev ...')
                pred_list_test = []
                model.eval()
                for i, (tokens, masks, segments, e_c, e_c_binary, e_c_emo) in enumerate(gold_dev_data_loader):
                    with torch.no_grad():
                        if USE_TOKEN_TYPE:
                            pred, _, _ = model(tokens.to(device), masks.to(device), segments.to(device))
                        else:
                            pred, _, _ = model(tokens.to(device), masks.to(device))
                        pred_list_test.append(pred.data.cpu().numpy())

                pred_list_test = np.argmax(np.concatenate(pred_list_test, axis=0), axis=1)
                # get_metrics(load_dev_labels('data/dev.txt'), pred_list_test)

                print('Gold Test ...')
                final_pred_list_test = []
                model.eval()
                for i, (tokens, masks, segments, e_c, e_c_binary, e_c_emo) in enumerate(gold_test_data_loader):
                    with torch.no_grad():
                        if USE_TOKEN_TYPE:
                            pred, _, _ = model(tokens.to(device), masks.to(device), segments.to(device))
                        else:
                            pred, _, _ = model(tokens.to(device), masks.to(device))
                        final_pred_list_test.append(pred.data.cpu().numpy())

                final_pred_list_test = np.argmax(np.concatenate(final_pred_list_test, axis=0), axis=1)
                # get_metrics(load_dev_labels('data/test.txt'), final_pred_list_test)

            if is_diverged:
                print("Reinitialize model ...")
                del model
                continue
            all_fold_results.append(pred_list_test_best)
            real_test_results.append(final_pred_best)

            del model
            break
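
Example #7's optimizer setup uses the standard BERT fine-tuning trick of excluding certain parameters from weight decay; 'gamma' and 'beta' target LayerNorm parameters under the original TensorFlow naming (in some PyTorch ports they are actually called 'weight'/'bias', in which case the filter is a no-op for them). The snippet passes the group key 'weight_decay_rate' to BertAdam from the old pytorch-pretrained-bert package; as a sketch, the same grouping with today's torch.optim.AdamW (which uses the key 'weight_decay') looks like this, with a stand-in model:

import torch

model = torch.nn.Linear(4, 2)  # stand-in; any nn.Module works
no_decay = ['bias', 'gamma', 'beta']
grouped_parameters = [
    {'params': [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(grouped_parameters, lr=2e-5)
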
Example #8
def train(X_train, y_train, X_dev, y_dev, X_test, y_test):
    num_labels = NUM_EMO

    vocab_size = VOCAB_SIZE

    print('NUM of VOCAB: ' + str(vocab_size))
    train_data = EmotionDataLoader(X_train, y_train, PAD_LEN)
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

    dev_data = EmotionDataLoader(X_dev, y_dev, PAD_LEN)
    dev_loader = DataLoader(dev_data, batch_size=int(BATCH_SIZE/3)+2, shuffle=False)

    test_data = EmotionDataLoader(X_test, y_test, PAD_LEN)
    test_loader = DataLoader(test_data, batch_size=int(BATCH_SIZE/3)+2, shuffle=False)

    model = AttentionLSTMClassifier(EMBEDDING_DIM, HIDDEN_DIM, vocab_size,
                                    num_labels, BATCH_SIZE, att_mode=opt.attention, soft_last=False)

    model.load_embedding(tokenizer.get_embeddings())
    # multi-GPU
    # model = nn.DataParallel(model)
    model.cuda()

    loss_criterion = nn.CrossEntropyLoss()  #

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    es = EarlyStopping(patience=PATIENCE)
    old_model = None
    for epoch in range(1, 300):
        print('Epoch: ' + str(epoch) + '===================================')
        train_loss = 0
        model.train()
        for i, (data, seq_len, label) in tqdm(enumerate(train_loader),
                                              total=len(train_data)/BATCH_SIZE):
            optimizer.zero_grad()
            y_pred = model(data.cuda(), seq_len)
            loss = loss_criterion(y_pred, label.view(-1).cuda())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIPS)
            optimizer.step()
            train_loss += loss.data.cpu().numpy() * data.shape[0]
            del y_pred, loss

        test_loss = 0
        model.eval()
        for _, (_data, _seq_len, _label) in enumerate(dev_loader):
            with torch.no_grad():
                y_pred = model(_data.cuda(), _seq_len)
                loss = loss_criterion(y_pred, _label.view(-1).cuda())
                test_loss += loss.data.cpu().numpy() * _data.shape[0]
                del y_pred, loss

        print("Train Loss: " + str(train_loss / len(train_data)) + \
              " Evaluation: " + str(test_loss / len(dev_data)))

        if es.step(test_loss):  # overfitting
            del model
            print('overfitting, loading best model ...')
            model = old_model
            break
        else:
            if es.is_best():
                if old_model is not None:
                    del old_model
                print('saving best model ...')
                old_model = deepcopy(model)
            else:
                print('not best model, ignoring ...')
                if old_model is None:
                    old_model = deepcopy(model)

    with open(f'lstm_{opt.dataset}_model.pt', 'bw') as f:
        torch.save(model.state_dict(), f)

    pred_list = []
    model.eval()
    for _, (_data, _seq_len, _label) in enumerate(test_loader):
        with torch.no_grad():
            y_pred = model(_data.cuda(), _seq_len)
            pred_list.append(y_pred.data.cpu().numpy())  # x[np.where( x > 3.0 )]
            del y_pred

    pred_list = np.argmax(np.concatenate(pred_list, axis=0), axis=1)

    return pred_list
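
Example #8 keeps the best model in memory rather than on disk: on every improvement it snapshots the network with deepcopy and rolls back to that snapshot once early stopping fires. The same pattern in isolation (a sketch; model, es, and the two epoch helpers are assumed/hypothetical):

from copy import deepcopy

best_model = None
for epoch in range(max_epochs):
    train_one_epoch(model)            # hypothetical helper
    dev_loss = evaluate_dev(model)    # hypothetical helper
    if es.step(dev_loss):             # patience exhausted: overfitting
        model = best_model            # roll back to the best snapshot
        break
    if es.is_best():
        if best_model is not None:
            del best_model            # free the old snapshot first
        best_model = deepcopy(model)
    elif best_model is None:
        best_model = deepcopy(model)  # always keep at least one snapshot
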
Example #9
class Trainer():
    def __init__(self, cfg, writer, img_writer, logger, run_id):
        # Copy shared config fields
        if "monodepth_options" in cfg:
            cfg["data"].update(cfg["monodepth_options"])
            cfg["model"].update(cfg["monodepth_options"])
            cfg["training"]["monodepth_loss"].update(cfg["monodepth_options"])
        if "generated_depth_dir" in cfg["data"]:
            dataset_name = f"{cfg['data']['dataset']}_" \
                           f"{cfg['data']['width']}x{cfg['data']['height']}"
            depth_teacher = cfg["data"].get("depth_teacher", None)
            assert not (depth_teacher and cfg['model'].get('depth_estimator_weights') is not None)
            if depth_teacher is not None:
                cfg["data"]["generated_depth_dir"] += dataset_name + "/" + depth_teacher + "/"
            else:
                cfg["data"]["generated_depth_dir"] += dataset_name + "/" + cfg['model']['depth_estimator_weights'] + "/"

        # Setup seeds
        setup_seeds(cfg.get("seed", 1337))
        if cfg["data"]["dataset_seed"] == "same":
            cfg["data"]["dataset_seed"] = cfg["seed"]

        # Setup device
        torch.backends.cudnn.benchmark = cfg["training"].get("benchmark", True)
        self.cfg = cfg
        self.writer = writer
        self.img_writer = img_writer
        self.logger = logger
        self.run_id = run_id
        self.mIoU = 0
        self.fwAcc = 0
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.setup_segmentation_unlabeled()

        self.unlabeled_require_depth = (
            self.cfg["training"]["unlabeled_segmentation"] is not None and
            self.cfg["training"]["unlabeled_segmentation"]["mix_mask"] in
            ("depth", "depthcomp", "depthhist"))

        # Prepare depth estimates
        do_precalculate_depth = self.cfg["training"]["segmentation_lambda"] != 0 and self.unlabeled_require_depth and \
                                self.cfg['model']['segmentation_name'] != 'mtl_pad'
        use_depth_teacher = cfg["data"].get("depth_teacher", None) is not None
        if do_precalculate_depth or use_depth_teacher:
            assert not (do_precalculate_depth and use_depth_teacher)
            if not self.cfg["training"].get("disable_depth_estimator", False):
                print("Prepare depth estimates")
                depth_estimator = DepthEstimator(cfg)
                depth_estimator.prepare_depth_estimates()
                del depth_estimator
                torch.cuda.empty_cache()
        else:
            self.cfg["data"]["generated_depth_dir"] = None

        # Setup Dataloader
        load_labels, load_sequence = True, True
        if self.cfg["training"]["monodepth_lambda"] == 0:
            load_sequence = False
        if self.cfg["training"]["segmentation_lambda"] == 0:
            load_labels = False
        train_data_cfg = deepcopy(self.cfg["data"])
        if not do_precalculate_depth and not use_depth_teacher:
            train_data_cfg["generated_depth_dir"] = None
        self.train_loader = build_loader(train_data_cfg, "train", load_labels=load_labels, load_sequence=load_sequence)
        if self.cfg["training"].get("minimize_entropy_unlabeled", False) or self.enable_unlabled_segmentation:
            unlabeled_segmentation_cfg = deepcopy(self.cfg["data"])
            if not self.only_unlabeled and self.mix_use_gt:
                unlabeled_segmentation_cfg["load_onehot"] = True
            if self.only_unlabeled:
                unlabeled_segmentation_cfg.update({"load_unlabeled": True, "load_labeled": False})
            elif self.only_labeled:
                unlabeled_segmentation_cfg.update({"load_unlabeled": False, "load_labeled": True})
            else:
                unlabeled_segmentation_cfg.update({"load_unlabeled": True, "load_labeled": True})
            if self.mix_video:
                assert not self.mix_use_gt and not self.only_labeled and not self.only_unlabeled, \
                    "Video sample indices are not compatible with non-video indices."
                unlabeled_segmentation_cfg.update({"only_sequences_with_segmentation": not self.mix_video,
                                                   "restrict_to_subset": None})
            self.unlabeled_loader = build_loader(unlabeled_segmentation_cfg, "train",
                                                 load_labels=load_labels if not self.mix_video else False,
                                                 load_sequence=load_sequence)
        else:
            self.unlabeled_loader = None
        self.val_loader = build_loader(self.cfg["data"], "val", load_labels=load_labels,
                                       load_sequence=load_sequence)
        self.n_classes = self.train_loader.n_classes

        # The monodepth dataloader settings use drop_last=True and shuffle=True, even for val
        self.train_data_loader = data.DataLoader(
            self.train_loader,
            batch_size=self.cfg["training"]["batch_size"],
            num_workers=self.cfg["training"]["n_workers"],
            shuffle=self.cfg["data"]["shuffle_trainset"],
            pin_memory=True,
            # Setting this to False causes a crash at the end of an epoch
            drop_last=True,
        )
        if self.unlabeled_loader is not None:
            self.unlabeled_data_loader = infinite_iterator(data.DataLoader(
                self.unlabeled_loader,
                batch_size=self.cfg["training"]["batch_size"],
                num_workers=self.cfg["training"]["n_workers"],
                shuffle=self.cfg["data"]["shuffle_trainset"],
                pin_memory=True,
                # Setting this to False causes a crash at the end of an epoch
                drop_last=True,
            ))

        self.val_batch_size = self.cfg["training"]["val_batch_size"]
        self.val_data_loader = data.DataLoader(
            self.val_loader,
            batch_size=self.val_batch_size,
            num_workers=self.cfg["training"]["n_workers"],
            pin_memory=True,
            # If using a dataset with odd number of samples (CamVid), the memory consumption suddenly increases for the
            # last batch. This can be circumvented by dropping the last batch. Only do that if it is necessary for your
            # system as it will result in an incomplete validation set.
            # drop_last=True,
        )

        # Setup Model
        self.model = get_model(cfg["model"], self.n_classes).to(self.device)
        # print(self.model)
        assert not (self.enable_unlabled_segmentation and self.cfg["training"]["save_monodepth_ema"])
        if self.enable_unlabled_segmentation and not self.only_labeled:
            print("Create segmentation ema model.")
            self.ema_model = self.create_ema_model(self.model).to(self.device)
        elif self.cfg["training"]["save_monodepth_ema"]:
            print("Create depth ema model.")
            # TODO: Try to remove unnecessary components and fit into gpu for better performance
            self.ema_model = self.create_ema_model(self.model)  # .to(self.device)
        else:
            self.ema_model = None

        # Setup optimizer, lr_scheduler and loss function
        optimizer_cls = get_optimizer(cfg)
        optimizer_params = {k: v for k, v in cfg["training"]["optimizer"].items() if
                            k not in ["name", "backbone_lr", "pose_lr", "depth_lr", "segmentation_lr"]}
        train_params = get_train_params(self.model, self.cfg)
        self.optimizer = optimizer_cls(train_params, **optimizer_params)

        self.scheduler = get_scheduler(self.optimizer, self.cfg["training"]["lr_schedule"])

        # Creates a GradScaler once at the beginning of training.
        self.scaler = GradScaler(enabled=self.cfg["training"]["amp"])

        self.loss_fn = get_segmentation_loss_function(self.cfg)
        self.monodepth_loss_calculator_train = get_monodepth_loss(self.cfg, is_train=True)
        self.monodepth_loss_calculator_val = get_monodepth_loss(self.cfg, is_train=False, batch_size=self.val_batch_size)

        if cfg["training"]["early_stopping"] is None:
            logger.info("Using No Early Stopping")
            self.earlyStopping = None
        else:
            self.earlyStopping = EarlyStopping(
                patience=round(cfg["training"]["early_stopping"]["patience"] / cfg["training"]["val_interval"]),
                min_delta=cfg["training"]["early_stopping"]["min_delta"],
                cumulative_delta=cfg["training"]["early_stopping"]["cum_delta"],
                logger=logger
            )
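
    # NOTE: `infinite_iterator` (wrapped around the unlabeled DataLoader in
    # __init__) is not defined in this excerpt. A minimal sketch, assuming it
    # simply cycles the loader forever so next() never raises StopIteration:
    @staticmethod
    def _infinite_iterator_sketch(loader):
        while True:
            # Re-entering the DataLoader restarts its sampler, so every pass
            # is reshuffled when shuffle=True
            for batch in loader:
                yield batch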

    def extract_monodepth_ema_params(self, model, ema_model):
        model_names = ["depth"]
        if not self.cfg["model"]["freeze_backbone"]:
            model_names.append("encoder")

        return extract_ema_params(model, ema_model, model_names)

    def extract_pad_ema_params(self, model, ema_model):
        model_names = ["depth", "encoder", "mtl_decoder"]
        return extract_ema_params(model, ema_model, model_names)
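
    # NOTE: `extract_ema_params` is a module-level helper not shown in this
    # excerpt. A minimal sketch consistent with its call sites, assuming the
    # model exposes its sub-networks via a `models` dict as in
    # save_monodepth_models below:
    @staticmethod
    def _extract_ema_params_sketch(model, ema_model, model_names):
        model_params, ema_params = [], []
        for name in model_names:
            model_params += list(model.models[name].parameters())
            ema_params += list(ema_model.models[name].parameters())
        return model_params, ema_params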

    def create_ema_model(self, model):
        ema_cfg = deepcopy(self.cfg["model"])
        ema_cfg["disable_pose"] = True
        ema_model = get_model(ema_cfg, self.n_classes)
        if self.cfg["training"]["save_monodepth_ema"]:
            mp, mcp = self.extract_monodepth_ema_params(model, ema_model)
        elif self.cfg['model']['segmentation_name'] == 'mtl_pad':
            mp, mcp = self.extract_pad_ema_params(model, ema_model)
        else:
            mp, mcp = list(model.parameters()), list(ema_model.parameters())
        for param in mcp:
            param.detach_()
        assert len(mp) == len(mcp), f"len(mp)={len(mp)}; len(mcp)={len(mcp)}"
        n = len(mp)
        for i in range(0, n):
            mcp[i].data[:] = mp[i].to(mcp[i].device, non_blocking=True).data[:].clone()
        return ema_model

    def update_ema_variables(self, ema_model, model, alpha_teacher, iteration):
        if self.cfg["training"]["save_monodepth_ema"]:
            model_params, ema_params = self.extract_monodepth_ema_params(model, ema_model)
        elif self.cfg['model']['segmentation_name'] == 'mtl_pad':
            model_params, ema_params = self.extract_pad_ema_params(model, ema_model)
        else:
            model_params, ema_params = model.parameters(), ema_model.parameters()
        # Use the "true" average until the exponential average is more correct
        alpha_teacher = min(1 - 1 / (iteration + 1), alpha_teacher)
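        # e.g. with alpha_teacher = 0.99 this ramp gives 0.0, 0.9, 0.99, 0.99
        # at iterations 0, 9, 99, 999: a plain running average for roughly the
        # first 1 / (1 - alpha) steps, a standard EMA afterwards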
        for ema_param, param in zip(ema_params, model_params):
            ema_param.data[:] = alpha_teacher * ema_param.data[:] + \
                                (1 - alpha_teacher) * param.data.to(ema_param.device, non_blocking=True)[:]
        return ema_model

    def save_resume(self, step):
        if self.ema_model is not None:
            raise NotImplementedError("ema model not supported")
        state = {
            "epoch": step + 1,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "scheduler_state": self.scheduler.state_dict(),
            "best_iou": self.best_iou,
        }
        save_path = os.path.join(
            self.writer.file_writer.get_logdir(),
            "best_model.pkl"
        )
        torch.save(state, save_path)
        return save_path

    def save_monodepth_models(self):
        if self.cfg["training"]["save_monodepth_ema"]:
            print("Save ema monodepth models.")
            assert self.ema_model is not None
            model_to_save = self.ema_model
        else:
            model_to_save = self.model
        models = ["depth", "pose_encoder", "pose"]
        if not self.cfg["model"]["freeze_backbone"]:
            models.append("encoder")
        for model_name in models:
            save_path = os.path.join(self.writer.file_writer.get_logdir(), "{}.pth".format(model_name))
            to_save = model_to_save.models[model_name].state_dict()
            torch.save(to_save, save_path)

    def load_resume(self, strict=True, load_model_only=False):
        if os.path.isfile(self.cfg["training"]["resume"]):
            self.logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(self.cfg["training"]["resume"])
            )
            checkpoint = torch.load(self.cfg["training"]["resume"])
            self.model.load_state_dict(checkpoint["model_state"], strict=strict)
            if not load_model_only:
                self.optimizer.load_state_dict(checkpoint["optimizer_state"])
                self.scheduler.load_state_dict(checkpoint["scheduler_state"])
            self.start_iter = checkpoint["epoch"]
            self.best_iou = checkpoint["best_iou"]
            self.logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    self.cfg["training"]["resume"], checkpoint["epoch"]
                )
            )
        else:
            self.logger.info("No checkpoint found at '{}'".format(self.cfg["training"]["resume"]))

    def tensorboard_training_images(self):
        num_saved = 0
        if self.cfg["training"]["n_tensorboard_trainimgs"] == 0:
            return
        for inputs in self.train_data_loader:
            images = inputs[("color_aug", 0, 0)]
            labels = inputs["lbl"]
            for img, label in zip(images.numpy(), labels.numpy()):
                if num_saved < self.cfg["training"]["n_tensorboard_trainimgs"]:
                    num_saved += 1
                    self.img_writer.add_image(
                        "trainset_{}/{}_0image".format(self.run_id.replace('/', '_'), num_saved), img,
                        global_step=0)
                    colored_image = self.val_loader.decode_segmap_tocolor(label)
                    self.img_writer.add_image(
                        "trainset_{}/{}_1ground_truth".format(self.run_id.replace('/', '_'), num_saved),
                        colored_image,
                        global_step=0, dataformats="HWC")
            if num_saved >= self.cfg["training"]["n_tensorboard_trainimgs"]:
                break

    def _train_batchnorm(self, model, train, only_encoder=False):
        if only_encoder:
            modules = model.models["encoder"].modules()
        else:
            modules = model.modules()
        for m in modules:
            if isinstance(m, nn.BatchNorm2d):
                m.train(train)

    def train_step(self, inputs, step):
        self.model.train()
        if self.ema_model is not None:
            self.ema_model.train()

        for k, v in inputs.items():
            if torch.is_tensor(v):
                inputs[k] = v.to(self.device, non_blocking=True)

        if self.enable_unlabled_segmentation:
            unlabeled_inputs = next(self.unlabeled_data_loader)
            for k in unlabeled_inputs.keys():
                if "color_aug" in k or "K" in k or "inv_K" in k or "color" in k or k in ["onehot_lbl", "pseudo_depth"]:
                    # print(f"Move {k} to gpu.")
                    unlabeled_inputs[k] = unlabeled_inputs[k].to(self.device, non_blocking=True)

        self.optimizer.zero_grad()
        segmentation_loss = torch.tensor(0)
        segmentation_total_loss = torch.tensor(0)
        mono_loss = torch.tensor(0)
        feat_dist_loss = torch.tensor(0)
        mono_total_loss = torch.tensor(0)

        if self.cfg["model"].get("freeze_backbone_bn", False):
            self._train_batchnorm(self.model, False, only_encoder=True)

        with autocast(enabled=self.cfg["training"]["amp"]):
            outputs = self.model(inputs)

        # Train monodepth
        if self.cfg["training"]["monodepth_lambda"] > 0:
            for k, v in outputs.items():
                if "depth" in k or "cam_T_cam" in k:
                    outputs[k] = v.to(torch.float32)
            self.monodepth_loss_calculator_train.generate_images_pred(inputs, outputs)
            mono_losses = self.monodepth_loss_calculator_train.compute_losses(inputs, outputs)
            mono_lambda = self.cfg["training"]["monodepth_lambda"]
            mono_loss = mono_lambda * mono_losses["loss"]
            feat_dist_lambda = self.cfg["training"]["feat_dist_lambda"]
            if feat_dist_lambda > 0:
                feat_dist = torch.dist(outputs["encoder_features"], outputs["imnet_features"], p=2)
                feat_dist_loss = feat_dist_lambda * feat_dist
            mono_total_loss = mono_loss + feat_dist_loss

            self.scaler.scale(mono_total_loss).backward(retain_graph=True)

        # Train depth on pseudo-labels
        if self.cfg["training"].get("pseudo_depth_lambda", 0) > 0:
            # Crop away bottom of image with own car
            with torch.no_grad():
                depth_loss_mask = torch.ones(outputs["disp", 0].shape, device=self.device)
                depth_loss_mask[:, :, int(outputs["disp", 0].shape[2] * 0.9):, :] = 0
            pseudo_depth_loss = berhu(outputs["disp", 0], inputs["pseudo_depth"], depth_loss_mask)
            pseudo_depth_loss *= self.cfg["training"]["pseudo_depth_lambda"]
            self.scaler.scale(pseudo_depth_loss).backward(retain_graph=True)
        else:
            pseudo_depth_loss = torch.tensor(0)

        # Train segmentation
        if self.cfg["training"]["segmentation_lambda"] > 0:
            with autocast(enabled=self.cfg["training"]["amp"]):
                segmentation_loss = self.loss_fn(input=outputs["semantics"], target=inputs["lbl"])
                if "intermediate_semantics" in outputs:
                    segmentation_loss += self.loss_fn(input=outputs["intermediate_semantics"],
                                                      target=inputs["lbl"])
                    segmentation_loss /= 2
                segmentation_loss *= self.cfg["training"]["segmentation_lambda"]
                segmentation_total_loss = segmentation_loss
            self.scaler.scale(segmentation_total_loss).backward()
            if self.enable_unlabled_segmentation:
                unlabeled_loss, unlabeled_mono_loss = self.train_step_segmentation_unlabeled(unlabeled_inputs, step)
                segmentation_total_loss += unlabeled_loss
                mono_total_loss += unlabeled_mono_loss

        if self.cfg["training"].get("clip_grad_norm") is not None:
            # Unscales the gradients of optimizer's assigned params in-place
            self.scaler.unscale_(self.optimizer)
            # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
            if self.cfg["training"].get("disable_depth_grad_clip", False):
                torch.nn.utils.clip_grad_norm_(get_params(self.model, ["encoder", "segmentation"]),
                                               self.cfg["training"]["clip_grad_norm"])
            else:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg["training"]["clip_grad_norm"])
        # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
        self.scaler.step(self.optimizer)
        self.scaler.update()
        if isinstance(self.scheduler, ReduceLROnPlateau):
            self.scheduler.step(metrics=self.mIoU)
        else:
            self.scheduler.step()

        # update Mean teacher network
        if self.ema_model is not None:
            self.ema_model = self.update_ema_variables(ema_model=self.ema_model, model=self.model,
                                                       alpha_teacher=0.99, iteration=step)

        total_loss = segmentation_total_loss + mono_total_loss + pseudo_depth_loss

        return {
            'segmentation_loss': segmentation_loss.detach(),
            'mono_loss': mono_loss.detach(),
            'pseudo_depth_loss': pseudo_depth_loss.detach(),
            'feat_dist_loss': feat_dist_loss.detach(),
            'segmentation_total_loss': segmentation_total_loss.detach(),
            'mono_total_loss': mono_total_loss.detach(),
            'total_loss': total_loss.detach()
        }
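
    # NOTE: `berhu` (used above for the pseudo-depth supervision) is not
    # defined in this excerpt. A minimal sketch of the reverse Huber loss with
    # the usual adaptive threshold c = 0.2 * max|error|; the `apply_log`
    # variant used in validate() is omitted:
    @staticmethod
    def _berhu_sketch(pred, target, mask):
        import torch
        diff = (pred - target).abs() * mask
        c = 0.2 * diff.max().clamp(min=1e-6)
        # L1 below the threshold, scaled L2 above it
        return torch.where(diff <= c, diff, (diff ** 2 + c ** 2) / (2 * c)).mean()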

    def setup_segmentation_unlabeled(self):
        if self.cfg["training"].get("unlabeled_segmentation", None) is None:
            self.enable_unlabled_segmentation = False
            return
        unlabeled_cfg = self.cfg["training"]["unlabeled_segmentation"]
        self.enable_unlabled_segmentation = True
        self.consistency_weight = unlabeled_cfg["consistency_weight"]
        self.mix_mask = unlabeled_cfg.get("mix_mask", None)
        self.unlabeled_color_jitter = unlabeled_cfg.get("color_jitter")
        self.unlabeled_blur = unlabeled_cfg.get("blur")
        self.only_unlabeled = unlabeled_cfg.get("only_unlabeled", True)
        self.only_labeled = unlabeled_cfg.get("only_labeled", False)
        self.mix_video = unlabeled_cfg.get("mix_video", False)
        assert not (self.only_unlabeled and self.only_labeled)
        self.mix_use_gt = unlabeled_cfg.get("mix_use_gt", False)
        self.unlabeled_debug_imgs = unlabeled_cfg.get("debug_images", False)
        self.depthcomp_margin = unlabeled_cfg["depthcomp_margin"]
        self.depthcomp_foreground_threshold = unlabeled_cfg["depthcomp_foreground_threshold"]
        self.unlabeled_backward_first_pseudo_label = unlabeled_cfg["backward_first_pseudo_label"]
        self.depthmix_online_depth = unlabeled_cfg.get("depthmix_online_depth", False)
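
    # A hypothetical `cfg["training"]["unlabeled_segmentation"]` block covering
    # every key read in setup_segmentation_unlabeled (values are illustrative,
    # not the authors' settings):
    _example_unlabeled_cfg = {
        "consistency_weight": 1.0,
        "mix_mask": "depthcomp",  # "class" | "depthcomp" | "depth" | "depthhist" | None
        "color_jitter": True,
        "blur": True,
        "only_unlabeled": True,
        "only_labeled": False,
        "mix_video": False,
        "mix_use_gt": False,
        "debug_images": False,
        "depthcomp_margin": 0.05,
        "depthcomp_foreground_threshold": 0.2,
        "backward_first_pseudo_label": False,
        "depthmix_online_depth": False,
    }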

    def generate_mix_mask(self, mode, argmax_u_w, unlabeled_imgs, depths):
        if mode == "class":
            for image_i in range(self.cfg["training"]["batch_size"]):
                classes = torch.unique(argmax_u_w[image_i])
                classes = classes[classes != 250]
                nclasses = classes.shape[0]
                classes = (classes[torch.Tensor(
                    np.random.choice(nclasses, int((nclasses - nclasses % 2) / 2), replace=False)).long()]).cuda()
                if image_i == 0:
                    MixMask = transformmasks.generate_class_mask(argmax_u_w[image_i], classes).unsqueeze(0).cuda()
                else:
                    MixMask = torch.cat(
                        (MixMask, transformmasks.generate_class_mask(argmax_u_w[image_i], classes).unsqueeze(0).cuda()))
        elif mode == "depthcomp":
            assert self.cfg["training"]["batch_size"] == 2
            for image_i, other_image_i in [(0, 1), (1, 0)]:
                own_disp = depths[image_i]
                other_disp = depths[other_image_i]
                # Margin avoids too much of mixing road with same depth
                foreground_mask = torch.ge(own_disp, other_disp - self.depthcomp_margin).long()
                # Avoid hiding the real background of the other image with own a bit closer background
                if isinstance(self.depthcomp_foreground_threshold, tuple) or isinstance(
                        self.depthcomp_foreground_threshold, list):
                    ft_l, ft_u = self.depthcomp_foreground_threshold
                    assert ft_u > ft_l
                    ft = torch.rand(1, device=own_disp.device) * (ft_u - ft_l) + ft_l
                else:
                    ft = self.depthcomp_foreground_threshold
                foreground_mask *= torch.ge(own_disp, ft).long()
                if image_i == 0:
                    MixMask = foreground_mask
                else:
                    MixMask = torch.cat((MixMask, foreground_mask))
        elif mode == "depth":
            for image_i in range(self.cfg["training"]["batch_size"]):
                generated_depth = depths[image_i]
                min_depth = 0.1
                max_depth = 0.4
                depth_threshold = torch.rand(1, device=depths.device) * (max_depth - min_depth) + min_depth
                if image_i == 0:
                    MixMask = transformmasks.generate_depth_mask(generated_depth, depth_threshold).cuda()
                else:
                    MixMask = torch.cat(
                        (MixMask, transformmasks.generate_depth_mask(generated_depth, depth_threshold).cuda()))
        elif mode == "depthhist":
            for image_i in range(self.cfg["training"]["batch_size"]):
                generated_depth = depths[image_i]
                # np.histogram needs a CPU tensor
                hist, bin_edges = np.histogram(torch.log(1 + generated_depth).cpu().flatten(), bins=100, density=True)
                # Fallbacks in case the thresholds below are never reached
                max_depth = torch.tensor([bin_edges[-1]])
                min_depth = torch.tensor([bin_edges[0]])
                # Exclude the first bin as it sometimes has a meaningless peak
                for v, e in zip(np.flip(hist)[1:], np.flip(bin_edges)[1:]):
                    if v > 1.5:
                        max_depth = torch.tensor([e])
                        break

                hist = np.cumsum(hist) / np.sum(hist)
                for v, e in zip(hist, bin_edges):
                    if v > 0.4:
                        min_depth = torch.tensor([e])
                        break
                depth_threshold = torch.rand(1) * (max_depth - min_depth) + min_depth
                if image_i == 0:
                    MixMask = transformmasks.generate_depth_mask(generated_depth, depth_threshold).cuda()
                else:
                    MixMask = torch.cat(
                        (MixMask, transformmasks.generate_depth_mask(generated_depth, depth_threshold).cuda()))
        elif mode is None:
            MixMask = torch.ones((unlabeled_imgs.shape[0], *unlabeled_imgs.shape[2:]), device=self.device)
        else:
            raise NotImplementedError(f"Unknown mix_mask {mode}")

        return MixMask
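
    # NOTE: `transformmasks.generate_class_mask` / `generate_depth_mask` are
    # not shown here. Minimal sketches consistent with the call sites above:
    @staticmethod
    def _generate_class_mask_sketch(pred, classes):
        # 1 where the pixel's (pseudo-)class is in `classes`, else 0
        return pred.unsqueeze(0).eq(classes.view(-1, 1, 1)).any(dim=0).long()

    @staticmethod
    def _generate_depth_mask_sketch(depth, threshold):
        import torch
        # 1 where the (normalized) disparity exceeds the sampled threshold
        return torch.ge(depth, threshold).long()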

    def calc_pseudo_label_loss(self, teacher_softmax, student_logits):
        max_probs, pseudo_label = torch.max(teacher_softmax, dim=1)
        pseudo_label[max_probs == 0] = self.unlabeled_loader.ignore_index
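        # Weight the pseudo-label loss by the fraction of pixels whose teacher
        # confidence reaches the (hard-coded) 0.968 threshold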
        unlabeled_weight = torch.sum(max_probs.ge(0.968).long() == 1).item() / np.prod(pseudo_label.shape)
        pixelWiseWeight = unlabeled_weight * torch.ones(max_probs.shape, device=self.device)
        L_u = self.consistency_weight * cross_entropy2d(input=student_logits, target=pseudo_label,
                                                        pixel_weights=pixelWiseWeight)
        return L_u, pseudo_label
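
    # NOTE: `cross_entropy2d` with `pixel_weights` is not shown here. A minimal
    # sketch, assuming an unreduced 2D cross-entropy rescaled per pixel:
    @staticmethod
    def _cross_entropy2d_sketch(input, target, pixel_weights, ignore_index=250):
        import torch.nn.functional as F
        # input: (N, C, H, W) logits; target: (N, H, W) class indices
        loss = F.cross_entropy(input, target, ignore_index=ignore_index,
                               reduction="none")
        return (loss * pixel_weights).mean()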

    def train_step_segmentation_unlabeled(self, unlabeled_inputs, step):
        def strongTransform(parameters, data=None, target=None):
            assert ((data is not None) or (target is not None))
            data, target = transformsgpu.mix(mask=parameters["Mix"], data=data, target=target)
            data, target = transformsgpu.color_jitter(jitter=parameters["ColorJitter"], data=data, target=target)
            data, target = transformsgpu.gaussian_blur(blur=parameters["GaussianBlur"], data=data, target=None)
            return data, target

        unlabeled_imgs = unlabeled_inputs[("color_aug", 0, 0)]

        # First Step: Run teacher to generate pseudo labels
        self.ema_model.use_pose_net = False
        logits_u_w = self.ema_model(unlabeled_inputs)["semantics"]
        softmax_u_w = torch.softmax(logits_u_w.detach(), dim=1)
        if self.mix_use_gt:
            with torch.no_grad():
                for i in range(unlabeled_imgs.shape[0]):
                    # .item() extracts the Python bool from the 0-dim tensor
                    if unlabeled_inputs["is_labeled"][i].item():
                        softmax_u_w[i] = unlabeled_inputs["onehot_lbl"][i]
        _, argmax_u_w = torch.max(softmax_u_w, dim=1)

        # Second Step: Run student network on unaugmented data to generate depth for DepthMix, calculate monodepth loss,
        # and unaugmented segmentation pseudo label loss
        mono_loss = 0
        L_1 = 0
        if self.depthmix_online_depth:
            outputs_1 = self.model(unlabeled_inputs)
            if self.cfg["training"]["monodepth_lambda"] > 0:
                self.monodepth_loss_calculator_train.generate_images_pred(unlabeled_inputs, outputs_1)
                mono_losses = self.monodepth_loss_calculator_train.compute_losses(unlabeled_inputs, outputs_1)
                mono_lambda = self.cfg["training"]["monodepth_lambda"]
                mono_loss = mono_lambda * mono_losses["loss"]
                self.scaler.scale(mono_loss).backward(retain_graph=self.unlabeled_backward_first_pseudo_label)
                depths = outputs_1[("disp", 0)].detach()
                for j in range(depths.shape[0]):
                    # Min-max normalize each disparity map to [0, 1]
                    dmin = torch.min(depths[j])
                    dmax = torch.max(depths[j])
                    depths[j] = (depths[j] - dmin) / (dmax - dmin)
            else:
                depths = unlabeled_inputs["pseudo_depth"]
            if self.unlabeled_backward_first_pseudo_label:
                logits_1 = outputs_1["semantics"]
                L_1, _ = self.calc_pseudo_label_loss(teacher_softmax=softmax_u_w, student_logits=logits_1)
                self.scaler.scale(L_1).backward()
        elif "pseudo_depth" in unlabeled_inputs:
            depths = unlabeled_inputs["pseudo_depth"]
        else:
            depths = [None] * unlabeled_imgs.shape[0]

        # Third Step: Run Mix
        MixMask = self.generate_mix_mask(self.mix_mask, argmax_u_w, unlabeled_imgs, depths)

        strong_parameters = {"Mix": MixMask}
        if self.unlabeled_color_jitter:
            strong_parameters["ColorJitter"] = random.uniform(0, 1)
        else:
            strong_parameters["ColorJitter"] = 0
        if self.unlabeled_blur:
            strong_parameters["GaussianBlur"] = random.uniform(0, 1)
        else:
            strong_parameters["GaussianBlur"] = 0

        inputs_u_s, _ = strongTransform(strong_parameters, data=unlabeled_imgs)
        unlabeled_inputs[("color_aug", 0, 0)] = inputs_u_s
        outputs = self.model(unlabeled_inputs)
        logits_u_s = outputs["semantics"]

        softmax_u_w_mixed, _ = strongTransform(strong_parameters, data=softmax_u_w)
        L_2, pseudo_label = self.calc_pseudo_label_loss(teacher_softmax=softmax_u_w_mixed, student_logits=logits_u_s)
        self.scaler.scale(L_2).backward()

        # Periodically dump the mixed image, mix mask, pseudo label, and depth
        # for debugging
        if (step + 1) % self.cfg["training"]["print_interval"] == 0:
            for j, (f, img, ps_lab, mask, d) in enumerate(
                    zip(unlabeled_inputs["filename"], inputs_u_s, pseudo_label, MixMask, depths)):
                fn = f"{self.cfg['training']['log_path']}/class_mix_debug/{step}_{j}_img.jpg"
                os.makedirs(os.path.dirname(fn), exist_ok=True)
                rows, cols = 2, 2
                fig, axs = plt.subplots(rows, cols, sharex='col', sharey='row',
                                        gridspec_kw={'hspace': 0, 'wspace': 0},
                                        figsize=(4 * cols, 4 * rows))
                axs[0][0].imshow(img.permute(1, 2, 0).cpu().numpy())
                axs[0][1].imshow(mask.float().cpu().numpy(), cmap="gray")
                if d is not None:
                    axs[1][1].imshow(d[0].cpu().numpy(), cmap="plasma")
                axs[1][0].imshow(self.val_loader.decode_segmap_tocolor(ps_lab.cpu().numpy()))
                for ax in axs.flat:
                    ax.axis("off")
                plt.savefig(fn)
                plt.close()

        return L_2 + L_1, mono_loss
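
    # NOTE: `transformsgpu.mix` is not shown here. A minimal DACS-style sketch:
    # pixels where mask == 1 keep their own sample, the rest come from the
    # other sample of the pair (flip(0) pairs the two images of a batch of 2):
    @staticmethod
    def _mix_sketch(mask, data=None, target=None):
        if data is not None:
            m = mask.unsqueeze(1).float()  # broadcast over channels
            data = m * data + (1 - m) * data.flip(0)
        if target is not None:
            m = mask.float()
            target = m * target + (1 - m) * target.flip(0)
        return data, target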

    def train(self):
        self.start_iter = 0
        self.best_iou = -100.0
        if self.cfg["training"]["resume"] is not None:
            self.load_resume()
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.cfg["training"]["optimizer"]["lr"]

        train_loss_meter = AverageMeterDict()
        time_meter = AverageMeter()

        step = self.start_iter
        flag = True

        self.tensorboard_training_images()

        start_ts = time.time()
        while step <= self.cfg["training"]["train_iters"] and flag:
            for inputs in self.train_data_loader:
                # torch.cuda.empty_cache()
                step += 1
                losses = self.train_step(inputs, step)

                time_meter.update(time.time() - start_ts)
                train_loss_meter.update(losses)

                if (step + 1) % self.cfg["training"]["print_interval"] == 0:
                    fmt_str = "Iter [{}/{}]  Loss: {:.4f}  Time/Image: {:.4f}"
                    print_str = fmt_str.format(
                        step + 1,
                        self.cfg["training"]["train_iters"],
                        train_loss_meter.avgs["total_loss"],
                        time_meter.avg / self.cfg["training"]["batch_size"],
                    )

                    self.logger.info(print_str)
                    for k, v in train_loss_meter.avgs.items():
                        self.writer.add_scalar("training/" + k, v, step + 1)
                    self.writer.add_scalar("training/learning_rate", get_lr(self.optimizer), step + 1)
                    self.writer.add_scalar("training/time_per_image",
                                           time_meter.avg / self.cfg["training"]["batch_size"], step + 1)
                    self.writer.add_scalar("training/amp_scale", self.scaler.get_scale(), step + 1)
                    self.writer.add_scalar("training/memory", psutil.virtual_memory().used / 1e9, step + 1)
                    time_meter.reset()
                    train_loss_meter.reset()

                if (step + 1) % current_val_interval(self.cfg, step + 1) == 0 or \
                        (step + 1) == self.cfg["training"]["train_iters"]:
                    self.validate(step)

                    if self.mIoU >= self.best_iou:
                        self.best_iou = self.mIoU
                        if self.cfg["training"]["save_model"]:
                            self.save_resume(step)

                    if self.earlyStopping is not None:
                        if not self.earlyStopping.step(self.mIoU):
                            flag = False
                            break

                if (step + 1) == self.cfg["training"]["train_iters"]:
                    flag = False
                    break

                start_ts = time.time()

        return step

    def validate(self, step):
        self.model.eval()
        val_loss_meter = AverageMeterDict()
        running_metrics_val = runningScore(self.n_classes)
        imgs_to_save = []
        with torch.no_grad():
            for inputs_val in tqdm(self.val_data_loader,
                                   total=len(self.val_data_loader)):
                if self.cfg["model"]["disable_monodepth"]:
                    required_inputs = [("color_aug", 0, 0), "lbl"]
                else:
                    required_inputs = inputs_val.keys()
                for k, v in inputs_val.items():
                    if torch.is_tensor(v) and k in required_inputs:
                        inputs_val[k] = v.to(self.device, non_blocking=True)
                images_val = inputs_val[("color_aug", 0, 0)]
                with autocast(enabled=self.cfg["training"]["amp"]):
                    outputs = self.model(inputs_val)

                if self.cfg["training"]["segmentation_lambda"] > 0:
                    labels_val = inputs_val["lbl"]
                    semantics = outputs["semantics"]
                    val_segmentation_loss = self.loss_fn(input=semantics, target=labels_val)
                    # Handle inconsistent size between input and target
                    n, c, h, w = semantics.size()
                    nt, ht, wt = labels_val.size()
                    if h != ht or w != wt:  # upsample predictions to label size
                        semantics = F.interpolate(
                            semantics, size=(ht, wt),
                            mode="bilinear", align_corners=True
                        )
                    pred = semantics.data.max(1)[1].cpu().numpy()
                    gt = labels_val.data.cpu().numpy()

                    running_metrics_val.update(gt, pred)
                else:
                    pred = [None] * images_val.shape[0]
                    gt = [None] * images_val.shape[0]
                    val_segmentation_loss = torch.tensor(0)

                if not self.cfg["model"]["disable_monodepth"]:
                    if not self.cfg["model"]["disable_pose"]:
                        self.monodepth_loss_calculator_val.generate_images_pred(inputs_val, outputs)
                        mono_losses = self.monodepth_loss_calculator_val.compute_losses(inputs_val, outputs)
                        val_mono_loss = mono_losses["loss"]
                    else:
                        outputs.update(self.model.predict_test_disp(inputs_val))
                        self.monodepth_loss_calculator_val.generate_depth_test_pred(outputs)
                        val_mono_loss = torch.tensor(0)
                else:
                    outputs[("disp", 0)] = [None] * images_val.shape[0]
                    val_mono_loss = torch.tensor(0)

                if self.cfg["data"].get("depth_teacher", None) is not None:
                    # Crop away bottom of image with own car
                    with torch.no_grad():
                        depth_loss_mask = torch.ones(outputs["disp", 0].shape, device=self.device)
                        depth_loss_mask[:, :, int(outputs["disp", 0].shape[2] * 0.9):, :] = 0
                    val_pseudo_depth_loss = berhu(outputs["disp", 0], inputs_val["pseudo_depth"], depth_loss_mask,
                                                  apply_log=self.cfg["training"].get("pseudo_depth_loss_log", False))
                else:
                    val_pseudo_depth_loss = torch.tensor(0)

                val_loss_meter.update({
                    "segmentation_loss": val_segmentation_loss.detach(),
                    "monodepth_loss": val_mono_loss.detach(),
                    "pseudo_depth_loss": val_pseudo_depth_loss.detach()
                })

                for img, label, output, depth in zip(images_val, gt, pred, outputs[("disp", 0)]):
                    if len(imgs_to_save) < self.cfg["training"]["n_tensorboard_imgs"]:
                        imgs_to_save.append([
                            img, label, output,
                            depth if depth is None else depth.detach()])

        for k, v in val_loss_meter.avgs.items():
            self.writer.add_scalar("validation/" + k, v, step + 1)
        if self.cfg["training"]["segmentation_lambda"] > 0:
            score, class_iou = running_metrics_val.get_scores()
            for k, v in score.items():
                print(k, v)
                self.writer.add_scalar("val_metrics/{}".format(k), v, step + 1)
            for k, v in class_iou.items():
                self.writer.add_scalar("val_metrics/cls_{}".format(k), v, step + 1)
            self.mIoU = score["Mean IoU : \t"]
            self.fwAcc = score["FreqW Acc : \t"]

        for j, imgs in enumerate(imgs_to_save):
            # Only log input and ground truth at the first validation since they don't change -> saves memory
            if (step + 1) // current_val_interval(self.cfg, step + 1) == 1:
                self.img_writer.add_image(
                    "{}/{}_0image".format(self.run_id.replace('/', '_'), j), imgs[0], global_step=step + 1)
                if imgs[1] is not None:
                    colored_image = self.val_loader.decode_segmap_tocolor(imgs[1])
                    self.img_writer.add_image(
                        "{}/{}_1ground_truth".format(self.run_id.replace('/', '_'), j), colored_image,
                        global_step=step + 1, dataformats="HWC")
            if imgs[2] is not None:
                colored_image = self.val_loader.decode_segmap_tocolor(imgs[2])
                self.img_writer.add_image(
                    "{}/{}_2prediction".format(self.run_id.replace('/', '_'), j), colored_image, global_step=step + 1,
                    dataformats="HWC")
            if imgs[3] is not None:
                colored_image = _colorize(imgs[3], "plasma", max_percentile=100)
                self.img_writer.add_image(
                    "{}/{}_3depth".format(self.run_id.replace('/', '_'), j), colored_image, global_step=step + 1,
                    dataformats="HWC")
Example #10
0
class PyTorchTrainer:
    def __init__(self, model, device, config, fold_num):
        self.config = config
        self.epoch = 0
        self.start_epoch = 0
        self.fold_num = fold_num
        if self.config.stage2:
            self.base_dir = f'./result/stage2/{config.dir}/{config.dir}_fold_{config.fold_num}'
        else:
            self.base_dir = f'./result/{config.dir}/{config.dir}_fold_{config.fold_num}'
        os.makedirs(self.base_dir, exist_ok=True)
        self.log_path = f'{self.base_dir}/log.txt'
        self.best_summary_loss = 10**5

        self.model = model
        self.swa_model = AveragedModel(self.model)
        self.device = device
        self.wandb = True

        self.cutmix = self.config.cutmix_ratio
        self.fmix = self.config.fmix_ratio
        self.smix = self.config.smix_ratio

        self.es = EarlyStopping(patience=8)

        self.scaler = GradScaler()
        self.amp = self.config.amp
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': 0.001,
            },
            {
                'params': [p for n, p in param_optimizer
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0,
            },
        ]

        self.optimizer, self.scheduler = get_optimizer(
            self.model, self.config.optimizer_name,
            self.config.optimizer_params, self.config.scheduler_name,
            self.config.scheduler_params, self.config.n_epochs)

        self.criterion = get_criterion(self.config.criterion_name,
                                       self.config.criterion_params)
        self.log(f'Fitter prepared. Device is {self.device}')
        set_wandb(self.config, fold_num)
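
    # NOTE: `optimizer_grouped_parameters` in __init__ is built but, in this
    # excerpt, never handed to an optimizer (get_optimizer receives the model
    # directly). If the bias/LayerNorm weight-decay split were intended, it
    # would typically be consumed like this (illustrative sketch, not the
    # authors' code):
    @staticmethod
    def _grouped_optimizer_sketch(model, lr=1e-4):
        import torch
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        groups = [
            {'params': [p for n, p in model.named_parameters()
                        if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.001},
            {'params': [p for n, p in model.named_parameters()
                        if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0},
        ]
        return torch.optim.AdamW(groups, lr=lr)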

    def fit(self, train_loader, validation_loader):
        if self.config.FIRST_FREEZE:
            self.model.freeze()

        for e in range(self.start_epoch, self.config.n_epochs):
            if self.config.verbose:
                lr = self.optimizer.param_groups[0]['lr']
                timestamp = datetime.utcnow().isoformat()
                self.log(f'\n{timestamp}\nLR: {lr}')
                wandb.log({"Epoch": self.epoch, "lr": lr}, step=e)

            if self.config.step_scheduler:
                self.scheduler.step(e)

            if e >= self.config.START_FREEZE and self.config.FREEZE:
                print('Model Frozen -> Train Classifier Only')
                self.model.freeze()
                self.config.FREEZE = False

            if e >= self.config.END_FREEZE and self.config.FIRST_FREEZE:
                print('Model Unfrozen -> Train Full Network')
                self.model.unfreeze()
                self.config.FIRST_FREEZE = False

            t = time.time()
            summary_loss, summary_scores, example_images = self.train_one_epoch(
                train_loader)
            torch.cuda.empty_cache()
            self.log(
                f'[RESULT]: Train. Epoch: {self.epoch}, Fold Num: {self.fold_num}, summary_loss: {summary_loss.avg:.5f}, summary_acc: {summary_scores.avg},  time: {(time.time() - t):.5f}'
            )
            self.save(
                f'{self.base_dir}/{self.config.dir}_fold_{self.fold_num}_last-checkpoint.bin'
            )
            wandb.log(
                {
                    f"Train_loss": summary_loss.avg,
                    f"Train_ACC": summary_scores.avg,
                    f"Example_{self.config.fold_num}": example_images
                },
                step=e)

            t = time.time()
            summary_loss, summary_scores = self.validation(validation_loader)
            torch.cuda.empty_cache()
            self.log(
                f'[RESULT]: Val. Epoch: {self.epoch}, summary_loss: {summary_loss.avg:.5f}, summary_acc: {summary_scores.avg},  time: {(time.time() - t):.5f}'
            )

            # NOTE: the improvement check below is disabled, so a "best"
            # checkpoint is saved after every epoch
            # if summary_loss.avg < self.best_summary_loss:
            self.best_summary_loss = summary_loss.avg
            self.model.eval()
            self.save(
                f'{self.base_dir}/{self.config.dir}_fold_{self.config.fold_num}_best-checkpoint-{str(self.epoch).zfill(3)}epoch.bin'
            )
            # for path in sorted(glob(f'{self.base_dir}/{self.config.dir}_fold_{self.config.fold_num}_best-checkpoint-*epoch.bin'))[:-3]:
            #     os.remove(path)

            if self.config.validation_scheduler:
                self.scheduler.step(metrics=summary_loss.avg)

            self.epoch += 1

    def validation(self, val_loader):
        self.model.eval()
        summary_loss = AverageMeter()
        summary_acc = AverageMeter()

        t = time.time()

        y_true = []
        y_pred = []
        for step, (images, targets) in enumerate(val_loader):
            if self.config.verbose:
                if step % self.config.verbose_step == 0:
                    print(
                        f'Val Step {step}/{len(val_loader)}, ' + \
                        f'summary_loss: {summary_loss.avg:.5f}, ' + \
                        f'time: {(time.time() - t):.5f}', end='\r'
                    )

            with torch.no_grad():
                targets = targets.to(self.device).float()
                batch_size = images.shape[0]
                images = images.to(self.device).float()
                _, outputs = self.model(images)
                loss = self.criterion(outputs, targets)

                # targets = targets.argmax(1)
                y_true.extend(targets.argmax(1).detach().cpu().numpy())
                y_pred.extend(outputs.argmax(1).detach().cpu().numpy())
                summary_loss.update(loss.detach().item(), batch_size)
                summary_acc.update(
                    (outputs.argmax(1) == targets.argmax(1)).sum().item() /
                    batch_size, batch_size)

        wandb.log(
            {
                f"Val_loss": summary_loss.avg,
                f"Val_ACC": summary_acc.avg,
            },
            step=self.epoch)

        if self.es.step(torch.tensor(summary_loss.avg)):
            self.log("Stop Early Stopiing")
            plot_confusion_matrix(y_true, y_pred)
            exit(0)

        if self.epoch == self.config.n_epochs - 1:
            plot_confusion_matrix(y_true, y_pred)
        return summary_loss, summary_acc
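
    # NOTE: the `EarlyStopping` used above is not shown. A minimal sketch of
    # the convention this class relies on (step() returns True once the
    # monitored loss has not improved for `patience` consecutive calls):
    class _EarlyStoppingSketch:
        def __init__(self, patience=8):
            self.patience = patience
            self.best = None
            self.num_bad = 0

        def step(self, metric):
            if self.best is None or metric < self.best:
                self.best = metric
                self.num_bad = 0
            else:
                self.num_bad += 1
            return self.num_bad >= self.patience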

    def train_one_epoch(self, train_loader):
        self.model.train()
        if self.epoch < self.config.freeze_bn_epoch:
            self.model.freeze_batchnorm_stats()

        summary_loss = AverageMeter()
        summary_acc = AverageMeter()

        example_images = []

        t = time.time()
        for step, (images, targets) in enumerate(train_loader):
            choice = np.random.rand(1)
            self.optimizer.zero_grad()
            if self.config.verbose:
                if step % self.config.verbose_step == 0:
                    print(
                        f'Train Step {step}/{len(train_loader)}, ' + \
                        f'summary_loss: {summary_loss.avg:.5f}, ' + \
                        f'time: {(time.time() - t):.5f}', end='\r'
                    )
            targets = targets.to(self.device).float()
            images = images.to(self.device).float()
            batch_size = images.shape[0]
            if self.config.FIRST_FREEZE and self.config.END_FREEZE > self.epoch:
                if self.amp:
                    with autocast():
                        _, outputs = self.model(images)
                        loss = self.criterion(outputs, targets)

                    self.scaler.scale(loss).backward()
                    # Unscale before clipping so the norm is measured on the
                    # true (unscaled) gradients
                    self.scaler.unscale_(self.optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), 1000)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()

            else:
                if self.amp:
                    with autocast():
                        if choice < self.cutmix:
                            aug_images, aug_targets = cutmix(
                                images, targets, 1.)
                            _, outputs = self.model(aug_images)
                            loss = mix_criterion(outputs, aug_targets,
                                                 self.criterion)
                        elif choice < self.cutmix + self.fmix:
                            aug_images, aug_targets = fmix(
                                images, targets, alpha=1., decay_power=3.,
                                shape=self.config.img_size, device=self.device)
                            aug_images = aug_images.to(self.device).float()
                            _, outputs = self.model(aug_images)
                            loss = mix_criterion(outputs, aug_targets,
                                                 self.criterion)
                        elif choice < self.cutmix + self.fmix + self.smix:
                            X, ya, yb, lam_a, lam_b = snapmix(images,
                                                              targets,
                                                              alpha=0.5,
                                                              model=self.model)
                            _, outputs, _ = self.model(X)
                            loss = self.snapmix_criterion(
                                self.criterion, outputs, ya, yb, lam_a, lam_b)
                        else:
                            _, outputs = self.model(images)
                            loss = self.criterion(outputs, targets)

                    self.scaler.scale(loss).backward()
                    # Unscale before clipping so the norm is measured on the
                    # true (unscaled) gradients
                    self.scaler.unscale_(self.optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), 1000)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()

                else:
                    if choice < self.cutmix:
                        aug_images, aug_targets = cutmix(images, targets, 1.)
                        _, outputs = self.model(aug_images)
                        loss = mix_criterion(outputs, aug_targets,
                                             self.criterion)
                    elif choice < self.cutmix + self.fmix:
                        aug_images, aug_targets = fmix(
                            images, targets, alpha=1., decay_power=3.,
                            shape=self.config.img_size, device=self.device)
                        aug_images = aug_images.to(self.device).float()
                        _, outputs = self.model(aug_images)
                        loss = mix_criterion(outputs, aug_targets,
                                             self.criterion)
                    elif choice < self.cutmix + self.fmix + self.smix:
                        X, ya, yb, lam_a, lam_b = snapmix(images,
                                                          targets,
                                                          alpha=0.5,
                                                          model=self.model)
                        _, outputs, _ = self.model(X)
                        loss = self.snapmix_criterion(self.criterion, outputs,
                                                      ya, yb, lam_a, lam_b)
                    else:
                        _, outputs = self.model(images)
                        loss = self.criterion(outputs, targets)

                    loss.backward()
                    # Clip after backward so there are gradients to clip
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), 1000)
                    self.optimizer.step()

            if len(example_images) < 16:
                example_images.append(wandb.Image(
                    images[0],
                    # caption=f"Truth: {targets[0].argmax(1).detach().cpu().item()}"
                ))

            summary_loss.update(loss.detach().item(), batch_size)
            summary_acc.update(
                (outputs.argmax(1) == targets.argmax(1)).sum().item() /
                batch_size, batch_size)

        return summary_loss, summary_acc, example_images
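
    # NOTE: `cutmix` and `mix_criterion` are not shown here. Minimal sketches
    # of the standard CutMix pairing assumed above (`aug_targets` carries both
    # label sets plus the area-corrected mixing ratio):
    @staticmethod
    def _cutmix_sketch(images, targets, alpha):
        import numpy as np
        import torch
        lam = np.random.beta(alpha, alpha)
        index = torch.randperm(images.size(0), device=images.device)
        h, w = images.shape[-2:]
        rh, rw = int(h * np.sqrt(1 - lam)), int(w * np.sqrt(1 - lam))
        cy, cx = np.random.randint(h), np.random.randint(w)
        y1, y2 = max(cy - rh // 2, 0), min(cy + rh // 2, h)
        x1, x2 = max(cx - rw // 2, 0), min(cx + rw // 2, w)
        images = images.clone()
        images[:, :, y1:y2, x1:x2] = images[index, :, y1:y2, x1:x2]
        lam = 1 - (y2 - y1) * (x2 - x1) / (h * w)
        return images, (targets, targets[index], lam)

    @staticmethod
    def _mix_criterion_sketch(outputs, aug_targets, criterion):
        ya, yb, lam = aug_targets
        return lam * criterion(outputs, ya) + (1 - lam) * criterion(outputs, yb)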

    def predict(self, test_loader, sub):
        self.model.eval()
        all_outputs = torch.tensor([], device=self.device)
        for step, (images, fnames) in enumerate(test_loader):
            with torch.no_grad():
                images = images.to(self.device).float()
                _, outputs = self.model(images)
                all_outputs = torch.cat((all_outputs, outputs), 0)

        sub.iloc[:, 1] = all_outputs.detach().cpu().numpy()
        return sub

    def save(self, path):
        self.model.eval()
        torch.save(
            {
                'model_state_dict': self.model.state_dict(),
                # 'optimizer_state_dict': self.optimizer.state_dict(),
                # 'scheduler_state_dict': self.scheduler.state_dict(),
                'best_summary_loss': self.best_summary_loss,
                'epoch': self.epoch,
            },
            path)

        wandb.save(path.split("/")[-1])

    def load(self, path):
        checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        # self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_summary_loss = checkpoint['best_summary_loss']
        self.epoch = checkpoint['epoch'] + 1
        self.start_epoch = checkpoint['epoch'] + 1

    def log(self, message):
        if self.config.verbose:
            print(message)
        with open(self.log_path, 'a+') as logger:
            logger.write(f'{message}\n')
Example #11
0
def train(X_train, y_train, X_dev, y_dev, X_test, y_test):
    num_labels = NUM_EMO

    vocab_size = VOCAB_SIZE

    print('NUM of VOCAB: ' + str(vocab_size))
    train_data = EmotionDataLoader(X_train, y_train, PAD_LEN)
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

    dev_data = EmotionDataLoader(X_dev, y_dev, PAD_LEN)
    dev_loader = DataLoader(dev_data,
                            batch_size=int(BATCH_SIZE / 3) + 2,
                            shuffle=False)

    test_data = EmotionDataLoader(X_test, y_test, PAD_LEN)
    test_loader = DataLoader(test_data,
                             batch_size=int(BATCH_SIZE / 3) + 2,
                             shuffle=False)

    model = AttentionLSTMClassifier(EMBEDDING_DIM,
                                    HIDDEN_DIM,
                                    vocab_size,
                                    num_labels,
                                    BATCH_SIZE,
                                    att_mode=opt.attention,
                                    use_glove=USE_GLOVE)

    if USE_GLOVE:
        model.load_embedding(tokenizer.get_embeddings())
    # multi-GPU
    # model = nn.DataParallel(model)
    model.cuda()

    if opt.loss == 'ce':
        loss_criterion = nn.CrossEntropyLoss()
    elif opt.loss == 'focal':
        loss_criterion = FocalLoss(gamma=2, reduce=True)
    else:
        raise Exception('loss option not recognised')

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    es = EarlyStopping(patience=PATIENCE)
    for epoch in range(1, 300):
        print('Epoch: ' + str(epoch) + '===================================')
        train_loss = 0
        model.train()
        for i, (data, seq_len,
                label) in tqdm(enumerate(train_loader),
                               total=(len(train_data) + BATCH_SIZE - 1) // BATCH_SIZE):
            optimizer.zero_grad()
            data_text = [tokenizer.decode_ids(x) for x in data]
            with torch.no_grad():
                character_ids = batch_to_ids(data_text).cuda()
                elmo_emb = elmo(character_ids)['elmo_representations']
                elmo_emb = (elmo_emb[0] + elmo_emb[1]) / 2  # avg of two layers

            y_pred = model(data.cuda(), seq_len, elmo_emb)
            loss = loss_criterion(y_pred, label.view(-1).cuda())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIPS)
            optimizer.step()
            train_loss += loss.data.cpu().numpy() * data.shape[0]
            del y_pred, loss

        test_loss = 0
        model.eval()
        for _, (_data, _seq_len, _label) in enumerate(dev_loader):
            with torch.no_grad():

                data_text = [tokenizer.decode_ids(x) for x in _data]
                character_ids = batch_to_ids(data_text).cuda()
                elmo_emb = elmo(character_ids)['elmo_representations']
                elmo_emb = (elmo_emb[0] + elmo_emb[1]) / 2  # avg of two layers

                y_pred = model(_data.cuda(), _seq_len, elmo_emb)
                loss = loss_criterion(y_pred, _label.view(-1).cuda())
                test_loss += loss.data.cpu().numpy() * _data.shape[0]
                del y_pred, loss

        print("Train Loss: " + str(train_loss / len(train_data)) + \
              " Evaluation: " + str(test_loss / len(dev_data)))

        if es.step(test_loss):
            print('Early stopping: dev loss stopped improving.')
            break

    with open(f'lstm_elmo_{opt.dataset}_model.pt', 'wb') as f:
        torch.save(model.state_dict(), f)

    pred_list = []
    model.eval()
    for _, (_data, _seq_len, _label) in enumerate(test_loader):
        with torch.no_grad():
            data_text = [tokenizer.decode_ids(x) for x in _data]
            character_ids = batch_to_ids(data_text).cuda()
            elmo_emb = elmo(character_ids)['elmo_representations']
            elmo_emb = (elmo_emb[0] + elmo_emb[1]) / 2  # avg of two layers

            y_pred = model(_data.cuda(), _seq_len, elmo_emb)
            pred_list.append(
                y_pred.data.cpu().numpy())  # x[np.where( x > 3.0 )]
            del y_pred

    pred_list = np.argmax(np.concatenate(pred_list, axis=0), axis=1)

    return pred_list
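
# NOTE: `FocalLoss(gamma=2, reduce=True)` above is not defined in this excerpt.
# A minimal multi-class sketch (Lin et al., 2017) consistent with the call:
import torch.nn as nn
import torch.nn.functional as F

class FocalLossSketch(nn.Module):
    def __init__(self, gamma=2, reduce=True):
        super().__init__()
        self.gamma = gamma
        self.reduce = reduce

    def forward(self, logits, target):
        # log p_t of the true class for each sample
        logpt = F.log_softmax(logits, dim=1).gather(1, target.unsqueeze(1)).squeeze(1)
        pt = logpt.exp()
        loss = -((1 - pt) ** self.gamma) * logpt
        return loss.mean() if self.reduce else loss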