Example #1
def train(params, m, datas):
    # early stopping
    es = EarlyStopping(mode='max', patience=params.cldc_patience)
    # set optimizer
    optimizer = get_optimizer(params, m)

    # train on one language, evaluate (dev/test) on another language
    # get the training data
    train_lang, train_data = get_lang_data(params, datas, training=True)
    # get the dev and test data; dev and test are in the same (target) language
    test_lang, test_data = get_lang_data(params, datas)

    # number of batches (ceiling division)
    n_batch = (train_data.train_size + params.cldc_bs - 1) // params.cldc_bs
    # per category
    data_idxs = [
        list(range(len(train_idx))) for train_idx in train_data.train_idxs
    ]

    # number of iterations
    cur_it = 0
    # write to tensorboard
    writer = SummaryWriter('./history/{}'.format(
        params.log_path)) if params.write_tfboard else None
    # best dev/test accuracy
    bdev = 0
    btest = 0
    # current dev/test accuracy
    cdev = 0
    ctest = 0
    dev_class_acc = {}
    test_class_acc = {}
    dev_cm = None
    test_cm = None
    # early-stopping warm-up flag: only count bad epochs once the training loss drops below cldc_lossth
    es_flag = False

    for i in range(params.cldc_ep):
        for data_idx in data_idxs:
            shuffle(data_idx)
        for j in range(n_batch):
            train_idxs = []
            for k, data_idx in enumerate(data_idxs):
                if j < n_batch - 1:
                    train_idxs.append(
                        data_idx[int(j * params.cldc_bs *
                                     train_data.train_prop[k]):int(
                                         (j + 1) * params.cldc_bs *
                                         train_data.train_prop[k])])
                elif j == n_batch - 1:
                    train_idxs.append(data_idx[int(j * params.cldc_bs *
                                                   train_data.train_prop[k]):])

            batch_train, batch_train_lens, batch_train_lb = get_batch(
                params, train_idxs, train_data.train_idxs,
                train_data.train_lens)
            optimizer.zero_grad()
            m.train()

            cldc_loss_batch, _, batch_pred = m(train_lang, batch_train,
                                               batch_train_lens,
                                               batch_train_lb)

            batch_acc, batch_acc_cls = get_classification_report(
                params,
                batch_train_lb.data.cpu().numpy(),
                batch_pred.data.cpu().numpy())

            if cldc_loss_batch < params.cldc_lossth:
                es_flag = True

            cldc_loss_batch.backward()
            out_cldc(i, j, n_batch, cldc_loss_batch, batch_acc, batch_acc_cls,
                     bdev, btest, cdev, ctest, es.num_bad_epochs)

            optimizer.step()
            cur_it += 1
            update_tensorboard(writer, cldc_loss_batch, batch_acc, cdev, ctest,
                               dev_class_acc, test_class_acc, cur_it)

            if cur_it % params.CLDC_VAL_EVERY == 0:
                sys.stdout.write('\n')
                sys.stdout.flush()
                # validation: dev comes from the training (source) language here;
                # the target-language dev set (test_data.dev_*) is not used
                cdev, dev_class_acc, dev_cm = test(params,
                                                   m,
                                                   train_data.dev_idxs,
                                                   train_data.dev_lens,
                                                   train_data.dev_size,
                                                   train_data.dev_prop,
                                                   train_lang,
                                                   cm=True)
                ctest, test_class_acc, test_cm = test(params,
                                                      m,
                                                      test_data.test_idxs,
                                                      test_data.test_lens,
                                                      test_data.test_size,
                                                      test_data.test_prop,
                                                      test_lang,
                                                      cm=True)
                print(dev_cm)
                print(test_cm)
                if es.step(cdev):
                    print('\nEarly Stopped.')
                    return
                elif es.is_better(cdev, bdev):
                    bdev = cdev
                    btest = ctest
                    #save_model(params, m)
                # reset bad epochs
                if not es_flag:
                    es.num_bad_epochs = 0
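The per-category slicing above builds each batch stratified by class: category k contributes a share train_prop[k] of the cldc_bs examples, and the last batch absorbs whatever is left. The same pattern recurs in the examples below. A minimal standalone sketch of that logic (function and argument names are illustrative, not from the original code):

def stratified_batch(data_idxs, batch_size, props, j, n_batch):
    # data_idxs: one shuffled list of positions per category
    # props[k]:  fraction of the training data belonging to category k
    # returns one slice of positions per category for batch j
    batch = []
    for k, idxs in enumerate(data_idxs):
        start = int(j * batch_size * props[k])
        if j < n_batch - 1:
            end = int((j + 1) * batch_size * props[k])
            batch.append(idxs[start:end])
        else:
            # the final batch takes whatever remains in each category
            batch.append(idxs[start:])
    return batch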
Example #2
def train(params, m, datas):
  # early stopping
  es = EarlyStopping(mode = 'max', patience = params.cldc_patience)
  # set optimizer
  optimizer = get_optimizer(params, m)
  # get initial parameters
  if params.zs_reg_alpha > 0:
    init_param_dict = {k: v.detach().clone() for k, v in m.named_parameters() if v.requires_grad}

  # get the training data
  train_lang, train_data = get_lang_data(params, datas, training = True)
  # dev & test are in the same lang
  test_lang, test_data = get_lang_data(params, datas)

  # number of batches (ceiling division)
  n_batch = (train_data.train_size + params.cldc_bs - 1) // params.cldc_bs
  # get the same n_batch for unlabelled data as well
  # batch size for unlabelled data
  rest_cldc_bs = train_data.rest_train_size // n_batch
  # per category
  data_idxs = [list(range(len(train_idx))) for train_idx in train_data.train_idxs]
  rest_data_idxs = list(range(len(train_data.rest_train_idxs)))
 
  # number of iterations
  cur_it = 0
  # write to tensorboard
  writer = SummaryWriter('./history/{}'.format(params.log_path)) if params.write_tfboard else None
  # best dev/test
  bdev = 0
  btest = 0
  # current dev/test
  cdev = 0
  ctest = 0
  dev_class_acc = {}
  test_class_acc = {}
  dev_cm = None
  test_cm = None
  # early-stopping warm-up flag: start counting bad epochs once the training loss falls below the threshold
  es_flag = False
  # set io function
  out_semicldc = getattr(ios, 'out_semicldc_{}'.format(params.cldc_train_mode))

  for i in range(params.cldc_ep):
    for data_idx in data_idxs:
      shuffle(data_idx)
    shuffle(rest_data_idxs)
    for j in range(n_batch):
      train_idxs = []
      for k, data_idx in enumerate(data_idxs):
        if j < n_batch - 1:
          train_idxs.append(data_idx[int(j * params.cldc_bs * train_data.train_prop[k]): int((j + 1) * params.cldc_bs * train_data.train_prop[k])])
          rest_train_idxs = rest_data_idxs[j * rest_cldc_bs: (j + 1) * rest_cldc_bs]
        elif j == n_batch - 1:
          train_idxs.append(data_idx[int(j * params.cldc_bs * train_data.train_prop[k]):])
          rest_train_idxs = rest_data_idxs[j * rest_cldc_bs:]

      # get batch data
      batch_train, batch_train_lens, batch_train_lb, batch_train_ohlb = get_batch(params, train_idxs, train_data.train_idxs, train_data.train_lens) 
      batch_rest_train, batch_rest_train_lens, batch_rest_train_lb, batch_rest_train_ohlb = get_rest_batch(params, rest_train_idxs, train_data.rest_train_idxs, train_data.rest_train_lens, enumerate_discrete)

      optimizer.zero_grad()
      m.train()

      # warm-up phase for the first cldc_warm_up_ep epochs
      m.warm_up = (i + 1 <= params.cldc_warm_up_ep)

      loss_dict, batch_pred = m(train_lang, 
                                batch_train, batch_train_lens, batch_train_lb, batch_train_ohlb, 
                                batch_rest_train, batch_rest_train_lens, batch_rest_train_lb, batch_rest_train_ohlb)
      # L2 regularisation of the parameters towards their initial values
      if params.zs_reg_alpha > 0:
        reg_loss = .0
        for k, v in m.named_parameters():
          if k in init_param_dict and v.requires_grad:
            reg_loss += torch.sum((v - init_param_dict[k]) ** 2)
        # debug: print the magnitude of the regularisation term
        print(reg_loss.detach())
        reg_loss *= params.zs_reg_alpha / 2
        reg_loss.backward()

      batch_acc, batch_acc_cls = get_classification_report(params, batch_train_lb.data.cpu().numpy(), batch_pred.data.cpu().numpy())

      if loss_dict['L_cldc_loss'] < params.cldc_lossth:
        es_flag = True

      # backpropagate the total loss (labelled + unlabelled terms)
      loss_dict['total_loss'].backward()
      out_semicldc(i, j, n_batch, loss_dict, batch_acc, batch_acc_cls, bdev, btest, cdev, ctest, es.num_bad_epochs)
      
      #torch.nn.utils.clip_grad_norm_(filter(lambda p: p.grad is not None and p.requires_grad, m.parameters()), 5)
      '''
      # debug for gradient
      for p_name, p in m.named_parameters():
        if p.grad is not None and p.requires_grad:
          print(p_name, p.grad.data.norm(2).item())
      '''

      optimizer.step()
      cur_it += 1
      update_tensorboard(params, writer, loss_dict, batch_acc, cdev, ctest, dev_class_acc, test_class_acc, cur_it)
      
      if cur_it % params.CLDC_VAL_EVERY == 0:
        sys.stdout.write('\n') 
        sys.stdout.flush()
        # validation 
        cdev, dev_class_acc, dev_cm = test(params, m, test_data.dev_idxs, test_data.dev_lens, test_data.dev_size, test_data.dev_prop, test_lang, cm = True)
        ctest, test_class_acc, test_cm = test(params, m, test_data.test_idxs, test_data.test_lens, test_data.test_size, test_data.test_prop, test_lang, cm = True)
        print(dev_cm)
        print(test_cm)
        if es.step(cdev):
          print('\nEarly Stopped.')
          # vis
          #if params.cldc_visualize:
            #tsne2d(params, m)
          # vis
          return
        elif es.is_better(cdev, bdev):
          bdev = cdev
          btest = ctest
          #save_model(params, m)
        # reset bad epochs
        if not es_flag:
          es.num_bad_epochs = 0
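The zs_reg_alpha branch in Example #2 penalises drift from the initial parameters with an L2 term, reg = (alpha / 2) * sum_k ||theta_k - theta_k_init||^2, backpropagated alongside the task loss. Pulled out into a standalone function for clarity (the name l2_to_init is illustrative; init_param_dict is the snapshot taken at the top of train):

import torch

def l2_to_init(model, init_param_dict, alpha):
    # (alpha / 2) * squared L2 distance of every trainable parameter
    # from the snapshot taken before fine-tuning
    reg = 0.0
    for name, param in model.named_parameters():
        if name in init_param_dict and param.requires_grad:
            reg = reg + torch.sum((param - init_param_dict[name]) ** 2)
    return alpha / 2.0 * reg

Calling l2_to_init(m, init_param_dict, params.zs_reg_alpha).backward() accumulates the same gradients as the in-line version above.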
Example #3
def main(params, m, data):
    # early stopping
    es = EarlyStopping(mode='max', patience=params.patience)
    # set optimizer
    optimizer = get_optimizer(params, m)

    # number of batches (ceiling division)
    n_batch = (data.train_size + params.bs - 1) // params.bs
    # per category
    data_idxs = [list(range(len(train_idx))) for train_idx in data.train_idxs]

    # number of iterations
    cur_it = 0
    # best dev/test accuracy
    bdev = 0
    btest = 0
    # current dev/test accuracy
    cdev = 0
    ctest = 0
    dev_class_acc = {}
    test_class_acc = {}
    dev_cm = None
    test_cm = None
    # early-stopping warm-up flag: only count bad epochs once the training loss drops below lossth
    es_flag = False

    for i in range(params.ep):
        # self-training
        if params.self_train or i >= params.semi_warm_up:
            params.self_train = True
            first_update = (i == params.semi_warm_up)
            # only for zero-shot
            if first_update:
                es.num_bad_epochs = 0
                es.best = 0
                bdev = 0
                btest = 0
            data = self_train_merge_data(params,
                                         m,
                                         es,
                                         data,
                                         first=first_update)
            # number of batches over the merged self-training set (ceiling division)
            n_batch = (data.self_train_size + params.bs - 1) // params.bs
            # per category
            data_idxs = [
                list(range(len(train_idx)))
                for train_idx in data.self_train_idxs
            ]

        for data_idx in data_idxs:
            shuffle(data_idx)
        for j in range(n_batch):
            train_idxs = []
            for k, data_idx in enumerate(data_idxs):
                if params.self_train:
                    train_prop = data.self_train_prop
                else:
                    train_prop = data.train_prop
                if j < n_batch - 1:
                    train_idxs.append(
                        data_idx[int(j * params.bs *
                                     train_prop[k]):int((j + 1) * params.bs *
                                                        train_prop[k])])
                elif j == n_batch - 1:
                    train_idxs.append(data_idx[int(j * params.bs *
                                                   train_prop[k]):])

            if params.self_train:
                batch_train, _, batch_train_lb = get_batch(
                    params, train_idxs, data.self_train_idxs,
                    data.self_train_lens)
            else:
                batch_train, _, batch_train_lb = get_batch(
                    params, train_idxs, data.train_idxs, data.train_lens)
            optimizer.zero_grad()
            m.train()

            loss_batch, logits = m(batch_train, labels=batch_train_lb)
            batch_pred = torch.argmax(logits, dim=1)

            batch_acc, batch_acc_cls = get_classification_report(
                params,
                batch_train_lb.data.cpu().numpy(),
                batch_pred.data.cpu().numpy())

            if loss_batch < params.lossth:
                es_flag = True

            loss_batch.backward()
            out_cldc(i, j, n_batch, loss_batch, batch_acc, batch_acc_cls, bdev,
                     btest, cdev, ctest, es.num_bad_epochs)

            optimizer.step()
            cur_it += 1

        sys.stdout.write('\n')
        sys.stdout.flush()
        # validation
        cdev, dev_class_acc, dev_cm = test(params,
                                           m,
                                           data.dev_idxs,
                                           data.dev_lens,
                                           data.dev_size,
                                           data.dev_prop,
                                           cm=True)
        ctest, test_class_acc, test_cm = test(params,
                                              m,
                                              data.test_idxs,
                                              data.test_lens,
                                              data.test_size,
                                              data.test_prop,
                                              cm=True)
        print(dev_cm)
        print(test_cm)
        if es.step(cdev):
            print('\nEarly Stopped.')
            return
        elif es.is_better(cdev, bdev):
            bdev = cdev
            btest = ctest
        # reset bad epochs
        if not es_flag:
            es.num_bad_epochs = 0
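All three loops drive an EarlyStopping object through the same small interface: step(metric) returns True once patience is exhausted, is_better(a, b) compares two metric values, and best / num_bad_epochs are public so the warm-up logic can reset them. The class itself is not part of these excerpts; a minimal sketch consistent with that usage (mode='max' means a larger dev score is better; min_delta is an added assumption):

class EarlyStopping:
    def __init__(self, mode='max', patience=10, min_delta=0.0):
        assert mode in ('min', 'max')
        self.mode = mode
        self.patience = patience
        self.min_delta = min_delta
        self.best = None
        self.num_bad_epochs = 0

    def is_better(self, current, best):
        # 'max' mode: a larger metric (e.g. dev accuracy) is better
        if self.mode == 'max':
            return current > best + self.min_delta
        return current < best - self.min_delta

    def step(self, metric):
        # update the best value and the bad-epoch counter;
        # return True when training should stop
        if self.best is None or self.is_better(metric, self.best):
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs > self.patience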