# Standard-library and PyTorch imports assumed by the examples below; project-level
# helpers (EarlyStopping, get_batch, calc_loss_batch, get_optimizer, test, the
# out_* / update_tensorboard printers, ...) are assumed to come from the
# surrounding repository and are not shown here.
import math
import sys
from random import shuffle

import torch
from torch import optim
# the original project may use tensorboardX's SummaryWriter instead
from torch.utils.tensorboard import SummaryWriter


def train(params, m, data_x, data_y):
  es = EarlyStopping(min_delta = params.min_delta, patience = params.patience)

  # optimizer
  optimizer = optim.Adam(filter(lambda p: p.requires_grad, m.parameters()), lr = params.init_learning_rate)
  
  n_batch = data_x.train_size // params.bs if data_x.train_size % params.bs == 0 else data_x.train_size // params.bs + 1
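  # equivalent, more readable form: n_batch = math.ceil(data_x.train_size / params.bs)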
  data_idxs = list(range(data_x.train_size))
  
  # number of iterations
  cur_it = 0
  # write to tensorboard
  writer = SummaryWriter('./history/{}'.format(params.emb_out_path)) if params.write_tfboard else None

  nll_dev = math.inf
  best_nll_dev = math.inf
  kld_dev = math.inf

  for i in range(params.ep):
    shuffle(data_idxs)
    for j in range(n_batch):
      train_idxs = data_idxs[j * params.bs: (j + 1) * params.bs]
      # get padded & sorted batch indices and their lengths
      padded_batch_x, batch_x_lens = get_batch(train_idxs, data_x, data_x.train_idxs, data_x.train_lens, params.cuda)
      padded_batch_y, batch_y_lens = get_batch(train_idxs, data_y, data_y.train_idxs, data_y.train_lens, params.cuda)

      optimizer.zero_grad()
      m.train()
      nll_batch, kld_batch = m(padded_batch_x, batch_x_lens, padded_batch_y, batch_y_lens)

      cur_it += 1
      loss_batch, alpha = calc_loss_batch(params, nll_batch, kld_batch, cur_it, n_batch)
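      # calc_loss_batch (project helper, not shown) returns the weighted loss and the
      # current KL weight `alpha`; a common VAE-style schedule (an assumption, not
      # confirmed by this snippet) anneals alpha linearly over the first iterations, e.g.
      #   alpha = min(1.0, cur_it / (n_warm_up_epochs * n_batch))   # hypothetical
      #   loss_batch = nll_batch + alpha * kld_batch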

      loss_batch.backward()
      optimizer.step()

      out_parallel(i, j, n_batch, loss_batch, nll_batch, kld_batch, best_nll_dev, nll_dev, kld_dev, es.num_bad_epochs)
      update_tensorboard(writer, loss_batch, nll_batch, kld_batch, alpha, nll_dev, kld_dev, cur_it)

      if cur_it % params.VAL_EVERY == 0:
        sys.stdout.write('\n') 
        sys.stdout.flush()
        # validation 
        nll_dev, kld_dev = test(params, m, data_x, data_y)
        if es.step(nll_dev):
          print('\nEarly Stopped.')
          return
        elif es.is_better(nll_dev, best_nll_dev):
          best_nll_dev = nll_dev
          # save model
          m.save_embedding(params, data_x, 'x')
          m.save_embedding(params, data_y, 'y')
          m.save_model(params, data_x, data_y, optimizer)
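
# Hypothetical usage sketch (not part of the original example): assumes `params`
# is a simple namespace carrying the fields referenced above, and that `model`,
# `data_x`, `data_y` follow the project's own model/data interfaces.
#
#   from argparse import Namespace
#   params = Namespace(min_delta=1e-3, patience=5, init_learning_rate=1e-3,
#                      bs=64, ep=20, VAL_EVERY=100, write_tfboard=False,
#                      emb_out_path='run0', cuda=torch.cuda.is_available())
#   train(params, model, data_x, data_y)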

# Example #2

def train(params, m, datas):
  # early stopping
  es = EarlyStopping(mode = 'max', patience = params.cldc_patience)
  # set optimizer
  optimizer = get_optimizer(params, m)
  # get initial parameters
  if params.zs_reg_alpha > 0:
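    # snapshot the initial trainable parameters (theta_0); the loop below penalises
    # drift from them with reg_loss = (zs_reg_alpha / 2) * ||theta - theta_0||^2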
    init_param_dict = {k: v.detach().clone() for k, v in m.named_parameters() if v.requires_grad}

  # training
  train_lang, train_data = get_lang_data(params, datas, training = True)
  # dev & test are in the same lang
  test_lang, test_data = get_lang_data(params, datas)

  n_batch = train_data.train_size // params.cldc_bs if train_data.train_size % params.cldc_bs == 0 else train_data.train_size // params.cldc_bs + 1
  # split the unlabelled data into the same number of batches (n_batch) and
  # derive the corresponding unlabelled batch size
  rest_cldc_bs = train_data.rest_train_size // n_batch
  # per category
  data_idxs = [list(range(len(train_idx))) for train_idx in train_data.train_idxs]
  rest_data_idxs = list(range(len(train_data.rest_train_idxs)))
 
  # number of iterations
  cur_it = 0
  # write to tensorboard
  writer = SummaryWriter('./history/{}'.format(params.log_path)) if params.write_tfboard else None
  # best dev/test
  bdev = 0
  btest = 0
  # current dev/test
  cdev = 0
  ctest = 0
  dev_class_acc = {}
  test_class_acc = {}
  dev_cm = None
  test_cm = None
  # early-stopping warm-up flag: only start counting bad epochs once the training loss drops below the threshold
  es_flag = False
  # set io function
  out_semicldc = getattr(ios, 'out_semicldc_{}'.format(params.cldc_train_mode))

  for i in range(params.cldc_ep):
    for data_idx in data_idxs:
      shuffle(data_idx)
    shuffle(rest_data_idxs)
    for j in range(n_batch):
      train_idxs = []
      for k, data_idx in enumerate(data_idxs):
        if j < n_batch - 1:
          train_idxs.append(data_idx[int(j * params.cldc_bs * train_data.train_prop[k]): int((j + 1) * params.cldc_bs * train_data.train_prop[k])])
          rest_train_idxs = rest_data_idxs[j * rest_cldc_bs: (j + 1) * rest_cldc_bs]
        elif j == n_batch - 1:
          train_idxs.append(data_idx[int(j * params.cldc_bs * train_data.train_prop[k]):])
          rest_train_idxs = rest_data_idxs[j * rest_cldc_bs:]
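      # each class k contributes roughly cldc_bs * train_prop[k] labelled examples,
      # so every batch keeps the class proportions of the training set; the
      # unlabelled pool is sliced with the flat rest_cldc_bs computed above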

      # get batch data
      batch_train, batch_train_lens, batch_train_lb, batch_train_ohlb = get_batch(params, train_idxs, train_data.train_idxs, train_data.train_lens) 
      batch_rest_train, batch_rest_train_lens, batch_rest_train_lb, batch_rest_train_ohlb = get_rest_batch(params, rest_train_idxs, train_data.rest_train_idxs, train_data.rest_train_lens, enumerate_discrete)

      optimizer.zero_grad()
      m.train()

      if i + 1 <= params.cldc_warm_up_ep:
        m.warm_up = True
      else:
        m.warm_up = False

      loss_dict, batch_pred = m(train_lang, 
                                batch_train, batch_train_lens, batch_train_lb, batch_train_ohlb, 
                                batch_rest_train, batch_rest_train_lens, batch_rest_train_lb, batch_rest_train_ohlb)
      # regularization term
      if params.zs_reg_alpha > 0:
        reg_loss = 0.0
        for k, v in m.named_parameters():
          if k in init_param_dict and v.requires_grad:
            reg_loss += torch.sum((v - init_param_dict[k]) ** 2)
        print(reg_loss.detach())
        reg_loss *= params.zs_reg_alpha / 2
        reg_loss.backward()

      batch_acc, batch_acc_cls = get_classification_report(params, batch_train_lb.data.cpu().numpy(), batch_pred.data.cpu().numpy())

      if loss_dict['L_cldc_loss'] < params.cldc_lossth:
        es_flag = True

      loss_dict['total_loss'].backward()
      out_semicldc(i, j, n_batch, loss_dict, batch_acc, batch_acc_cls, bdev, btest, cdev, ctest, es.num_bad_epochs)
      
      #torch.nn.utils.clip_grad_norm_(filter(lambda p: p.grad is not None and p.requires_grad, m.parameters()), 5)
      '''
      # debug for gradient
      for p_name, p in m.named_parameters():
        if p.grad is not None and p.requires_grad:
          print(p_name, p.grad.data.norm(2).item())
      '''

      optimizer.step()
      cur_it += 1
      update_tensorboard(params, writer, loss_dict, batch_acc, cdev, ctest, dev_class_acc, test_class_acc, cur_it)
      
      if cur_it % params.CLDC_VAL_EVERY == 0:
        sys.stdout.write('\n') 
        sys.stdout.flush()
        # validation 
        cdev, dev_class_acc, dev_cm = test(params, m, test_data.dev_idxs, test_data.dev_lens, test_data.dev_size, test_data.dev_prop, test_lang, cm = True)
        ctest, test_class_acc, test_cm = test(params, m, test_data.test_idxs, test_data.test_lens, test_data.test_size, test_data.test_prop, test_lang, cm = True)
        print(dev_cm)
        print(test_cm)
        if es.step(cdev):
          print('\nEarly Stopped.')
          # vis
          #if params.cldc_visualize:
            #tsne2d(params, m)
          # vis
          return
        elif es.is_better(cdev, bdev):
          bdev = cdev
          btest = ctest
          #save_model(params, m)
        # reset bad epochs
        if not es_flag:
          es.num_bad_epochs = 0

# Example #3

def train(params, m, datas):
    es = EarlyStopping(min_delta=params.min_delta, patience=params.patience)

    # optimizer
    ps = [p[1] for p in m.named_parameters() if 'discriminator' not in p[0]]
    print('Model parameter: {}'.format(sum(p.numel() for p in ps)))
    optimizer = optim.Adam(ps, lr=params.init_lr)
    if params.adv_training:
        dis_ps = [
            p[1] for p in m.named_parameters() if 'discriminator' in p[0]
        ]
        dis_optimizer = optim.Adam(dis_ps, lr=params.init_lr)
        dis_enc_ps = [
            p[1] for p in m.named_parameters()
            if 'encoder' in p[0] or 'embedding' in p[0]
        ]
        dis_enc_optimizer = optim.Adam(dis_enc_ps, lr=params.init_lr)
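    # under adversarial training there are three parameter groups: `optimizer`
    # covers everything except the discriminator, `dis_optimizer` the discriminator
    # itself, and `dis_enc_optimizer` the encoder/embedding parameters that are
    # trained against the discriminator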

    # all training instances, split between the two languages; at the moment the data are balanced
    n_batch = len(datas) * datas[0].train_size // params.bs if len(
        datas) * datas[0].train_size % params.bs == 0 else len(
            datas) * datas[0].train_size // params.bs + 1
    data_idxs = {}
    for i, data in enumerate(datas):
        lang = data.vocab.lang
        data_idxs[lang] = list(range(data.train_size))

    # number of iterations
    cur_it = 0
    # write to tensorboard
    writer = SummaryWriter('./history/{}'.format(
        params.log_path)) if params.write_tfboard else None

    nll_dev = math.inf
    best_nll_dev = math.inf
    kld_dev = math.inf

    for i in range(params.ep):
        for lang in data_idxs:
            shuffle(data_idxs[lang])
        for j in range(n_batch):
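            # 'xl' / 'xl-adv': batches alternate languages round-robin
            # (datas[j % len(datas)]), and each language's index window advances
            # once every len(datas) iterations; 'mo': monolingual training on
            # params.langs[0]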
            if params.task == 'xl' or params.task == 'xl-adv':
                lang_idx = j % len(datas)
                data = datas[lang_idx]
                lang = data.vocab.lang
                train_idxs = data_idxs[lang][j // len(datas) *
                                             params.bs:(j // len(datas) + 1) *
                                             params.bs]
            elif params.task == 'mo':
                lang = params.langs[0]
                lang_idx = params.lang_dict[lang]
                data = datas[lang_idx]
                train_idxs = data_idxs[lang][j * params.bs:(j + 1) * params.bs]
            padded_batch, batch_lens = get_batch(train_idxs, data,
                                                 data.train_idxs,
                                                 data.train_lens, params.cuda)

            optimizer.zero_grad()
            if params.adv_training:
                dis_optimizer.zero_grad()
                dis_enc_optimizer.zero_grad()
            m.train()

            nll_batch, kld_batch, ls_dis, ls_enc = m(lang, padded_batch,
                                                     batch_lens)

            cur_it += 1
            loss_batch, alpha = calc_loss_batch(params, nll_batch, kld_batch,
                                                cur_it, n_batch)
            '''
            # add adversarial loss to the encoder
            if cur_it > params.adv_ep * n_batch:
              loss_batch += ls_enc
            '''

            if not params.adv_training:
                loss_batch.backward()
                optimizer.step()
            else:
                ls_dis = ls_dis.mean()
                ls_enc = ls_enc.mean()
                loss_batch = loss_batch + ls_dis + ls_enc
                loss_batch.backward()
                optimizer.step()
                dis_optimizer.step()
                dis_enc_optimizer.step()
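            # note: encoder/embedding parameters belong to both `optimizer` and
            # `dis_enc_optimizer`, so under adversarial training they receive two
            # updates per iteration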

            out_xling(i,
                      j,
                      n_batch,
                      loss_batch,
                      nll_batch,
                      kld_batch,
                      best_nll_dev,
                      nll_dev,
                      kld_dev,
                      es.num_bad_epochs,
                      ls_dis=ls_dis,
                      ls_enc=ls_enc)
            update_tensorboard(writer,
                               loss_batch,
                               nll_batch,
                               kld_batch,
                               alpha,
                               nll_dev,
                               kld_dev,
                               cur_it,
                               ls_dis=ls_dis,
                               ls_enc=ls_enc)

            if cur_it % params.VAL_EVERY == 0:
                sys.stdout.write('\n')
                sys.stdout.flush()
                # validation
                nll_dev, kld_dev = test(params, m, datas)
                if es.step(nll_dev):
                    print('\nEarly Stopped.')
                    return
                elif es.is_better(nll_dev, best_nll_dev):
                    best_nll_dev = nll_dev
                    # save model
                    for lang in params.langs:
                        lang_idx = params.lang_dict[lang]
                        m.save_embedding(params, datas[lang_idx])
                    m.save_model(params, datas)

# Example #4

def train(params, m, datas):
    # early stopping
    es = EarlyStopping(mode='max', patience=params.cldc_patience)
    # set optimizer
    optimizer = get_optimizer(params, m)

    # train on one language; dev/test are in another language
    # get training
    train_lang, train_data = get_lang_data(params, datas, training=True)
    # get dev and test, dev is the same language as test
    test_lang, test_data = get_lang_data(params, datas)

    n_batch = train_data.train_size // params.cldc_bs if train_data.train_size % params.cldc_bs == 0 else train_data.train_size // params.cldc_bs + 1
    # per category
    data_idxs = [
        list(range(len(train_idx))) for train_idx in train_data.train_idxs
    ]

    # number of iterations
    cur_it = 0
    # write to tensorboard
    writer = SummaryWriter('./history/{}'.format(
        params.log_path)) if params.write_tfboard else None
    # best dev/test accuracy
    bdev = 0
    btest = 0
    # current dev/test accuracy
    cdev = 0
    ctest = 0
    dev_class_acc = {}
    test_class_acc = {}
    dev_cm = None
    test_cm = None
    # early-stopping warm-up flag: only start counting bad epochs once the training loss drops below the threshold
    es_flag = False

    for i in range(params.cldc_ep):
        for data_idx in data_idxs:
            shuffle(data_idx)
        for j in range(n_batch):
            train_idxs = []
            for k, data_idx in enumerate(data_idxs):
                if j < n_batch - 1:
                    train_idxs.append(
                        data_idx[int(j * params.cldc_bs *
                                     train_data.train_prop[k]):int(
                                         (j + 1) * params.cldc_bs *
                                         train_data.train_prop[k])])
                elif j == n_batch - 1:
                    train_idxs.append(data_idx[int(j * params.cldc_bs *
                                                   train_data.train_prop[k]):])

            batch_train, batch_train_lens, batch_train_lb = get_batch(
                params, train_idxs, train_data.train_idxs,
                train_data.train_lens)
            optimizer.zero_grad()
            m.train()

            cldc_loss_batch, _, batch_pred = m(train_lang, batch_train,
                                               batch_train_lens,
                                               batch_train_lb)

            batch_acc, batch_acc_cls = get_classification_report(
                params,
                batch_train_lb.data.cpu().numpy(),
                batch_pred.data.cpu().numpy())

            if cldc_loss_batch < params.cldc_lossth:
                es_flag = True

            cldc_loss_batch.backward()
            out_cldc(i, j, n_batch, cldc_loss_batch, batch_acc, batch_acc_cls,
                     bdev, btest, cdev, ctest, es.num_bad_epochs)

            optimizer.step()
            cur_it += 1
            update_tensorboard(writer, cldc_loss_batch, batch_acc, cdev, ctest,
                               dev_class_acc, test_class_acc, cur_it)

            if cur_it % params.CLDC_VAL_EVERY == 0:
                sys.stdout.write('\n')
                sys.stdout.flush()
                # validation
                #cdev, dev_class_acc, dev_cm = test(params, m, test_data.dev_idxs, test_data.dev_lens, test_data.dev_size, test_data.dev_prop, test_lang, cm = True)
                cdev, dev_class_acc, dev_cm = test(params,
                                                   m,
                                                   train_data.dev_idxs,
                                                   train_data.dev_lens,
                                                   train_data.dev_size,
                                                   train_data.dev_prop,
                                                   train_lang,
                                                   cm=True)
                ctest, test_class_acc, test_cm = test(params,
                                                      m,
                                                      test_data.test_idxs,
                                                      test_data.test_lens,
                                                      test_data.test_size,
                                                      test_data.test_prop,
                                                      test_lang,
                                                      cm=True)
                print(dev_cm)
                print(test_cm)
                if es.step(cdev):
                    print('\nEarly Stopped.')
                    return
                elif es.is_better(cdev, bdev):
                    bdev = cdev
                    btest = ctest
                    #save_model(params, m)
                # reset bad epochs
                if not es_flag:
                    es.num_bad_epochs = 0


def main(params, m, data):
    # early stopping
    es = EarlyStopping(mode='max', patience=params.patience)
    # set optimizer
    optimizer = get_optimizer(params, m)

    n_batch = data.train_size // params.bs if data.train_size % params.bs == 0 else data.train_size // params.bs + 1
    # per category
    data_idxs = [list(range(len(train_idx))) for train_idx in data.train_idxs]

    # number of iterations
    cur_it = 0
    # best dev/test accuracy
    bdev = 0
    btest = 0
    # current dev/test accuracy
    cdev = 0
    ctest = 0
    dev_class_acc = {}
    test_class_acc = {}
    dev_cm = None
    test_cm = None
    # early-stopping warm-up flag: only start counting bad epochs once the training loss drops below the threshold
    es_flag = False

    for i in range(params.ep):
        # self-training
        if params.self_train or i >= params.semi_warm_up:
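            # once epoch i reaches semi_warm_up, self-training stays on for the rest
            # of the run: early-stopping state and best scores are reset on the first
            # switch, and the training pool is rebuilt each epoch from pseudo-labelled
            # data by self_train_merge_data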
            params.self_train = True
            first_update = (i == params.semi_warm_up)
            # only for zero-shot
            if first_update:
                es.num_bad_epochs = 0
                es.best = 0
                bdev = 0
                btest = 0
            data = self_train_merge_data(params,
                                         m,
                                         es,
                                         data,
                                         first=first_update)
            n_batch = data.self_train_size // params.bs if data.self_train_size % params.bs == 0 else data.self_train_size // params.bs + 1
            # per category
            data_idxs = [
                list(range(len(train_idx)))
                for train_idx in data.self_train_idxs
            ]

        for data_idx in data_idxs:
            shuffle(data_idx)
        for j in range(n_batch):
            train_idxs = []
            for k, data_idx in enumerate(data_idxs):
                if params.self_train:
                    train_prop = data.self_train_prop
                else:
                    train_prop = data.train_prop
                if j < n_batch - 1:
                    train_idxs.append(
                        data_idx[int(j * params.bs *
                                     train_prop[k]):int((j + 1) * params.bs *
                                                        train_prop[k])])
                elif j == n_batch - 1:
                    train_idxs.append(data_idx[int(j * params.bs *
                                                   train_prop[k]):])

            if params.self_train:
                batch_train, _, batch_train_lb = get_batch(
                    params, train_idxs, data.self_train_idxs,
                    data.self_train_lens)
            else:
                batch_train, _, batch_train_lb = get_batch(
                    params, train_idxs, data.train_idxs, data.train_lens)
            optimizer.zero_grad()
            m.train()

            loss_batch, logits = m(batch_train, labels=batch_train_lb)
            batch_pred = torch.argmax(logits, dim=1)
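            # the classifier returns (loss, logits) when labels are supplied (a
            # transformers-style interface; an assumption, not confirmed by this
            # snippet), and predictions are the argmax over the class dimension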

            batch_acc, batch_acc_cls = get_classification_report(
                params,
                batch_train_lb.data.cpu().numpy(),
                batch_pred.data.cpu().numpy())

            if loss_batch < params.lossth:
                es_flag = True

            loss_batch.backward()
            out_cldc(i, j, n_batch, loss_batch, batch_acc, batch_acc_cls, bdev,
                     btest, cdev, ctest, es.num_bad_epochs)

            optimizer.step()
            cur_it += 1

        sys.stdout.write('\n')
        sys.stdout.flush()
        # validation
        cdev, dev_class_acc, dev_cm = test(params,
                                           m,
                                           data.dev_idxs,
                                           data.dev_lens,
                                           data.dev_size,
                                           data.dev_prop,
                                           cm=True)
        ctest, test_class_acc, test_cm = test(params,
                                              m,
                                              data.test_idxs,
                                              data.test_lens,
                                              data.test_size,
                                              data.test_prop,
                                              cm=True)
        print(dev_cm)
        print(test_cm)
        if es.step(cdev):
            print('\nEarly Stopped.')
            return
        elif es.is_better(cdev, bdev):
            bdev = cdev
            btest = ctest
        # reset bad epochs
        if not es_flag:
            es.num_bad_epochs = 0