def train(params, m, data_x, data_y):
    """Train m on parallel (x, y) data with periodic dev validation,
    early stopping on dev NLL, and checkpointing on improvement.

    Args:
      params: experiment configuration (batch size, lr, epochs, paths, ...).
      m: the model; forward returns (nll_batch, kld_batch).
      data_x, data_y: the two sides of the parallel training data.
    """
    es = EarlyStopping(min_delta = params.min_delta, patience = params.patience)
    # optimizer over trainable parameters only
    trainable_params = [p for p in m.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr = params.init_learning_rate)
    # ceil(train_size / bs)
    n_batch = (data_x.train_size + params.bs - 1) // params.bs
    data_idxs = list(range(data_x.train_size))
    # global iteration counter (drives KL annealing and validation cadence)
    cur_it = 0
    # optional tensorboard logging
    writer = SummaryWriter('./history/{}'.format(params.emb_out_path)) if params.write_tfboard else None
    nll_dev, best_nll_dev, kld_dev = math.inf, math.inf, math.inf
    for epoch in range(params.ep):
        shuffle(data_idxs)
        for b in range(n_batch):
            batch_idxs = data_idxs[b * params.bs: (b + 1) * params.bs]
            # padded & length-sorted batches for both sides
            padded_batch_x, batch_x_lens = get_batch(batch_idxs, data_x, data_x.train_idxs, data_x.train_lens, params.cuda)
            padded_batch_y, batch_y_lens = get_batch(batch_idxs, data_y, data_y.train_idxs, data_y.train_lens, params.cuda)
            optimizer.zero_grad()
            m.train()
            nll_batch, kld_batch = m(padded_batch_x, batch_x_lens, padded_batch_y, batch_y_lens)
            cur_it += 1
            loss_batch, alpha = calc_loss_batch(params, nll_batch, kld_batch, cur_it, n_batch)
            loss_batch.backward()
            optimizer.step()
            out_parallel(epoch, b, n_batch, loss_batch, nll_batch, kld_batch, best_nll_dev, nll_dev, kld_dev, es.num_bad_epochs)
            update_tensorboard(writer, loss_batch, nll_batch, kld_batch, alpha, nll_dev, kld_dev, cur_it)
            if cur_it % params.VAL_EVERY != 0:
                continue
            sys.stdout.write('\n')
            sys.stdout.flush()
            # validation on the dev split
            nll_dev, kld_dev = test(params, m, data_x, data_y)
            if es.step(nll_dev):
                print('\nEarly Stoped.')
                return
            if es.is_better(nll_dev, best_nll_dev):
                best_nll_dev = nll_dev
                # checkpoint embeddings and full model on dev improvement
                m.save_embedding(params, data_x, 'x')
                m.save_embedding(params, data_y, 'y')
                m.save_model(params, data_x, data_y, optimizer)
def train(params, m, datas):
  """Semi-supervised CLDC training loop.

  Draws labelled batches per category (proportional to train_prop) plus
  unlabelled ("rest") batches, optionally anchors the weights to their
  initial values with an L2 penalty (zero-shot regularisation), validates
  every CLDC_VAL_EVERY iterations on the test language, and early-stops
  (mode='max', i.e. maximising a dev score).

  Args:
    params: experiment configuration namespace.
    m: the model; called with labelled + unlabelled batches, returns
       (loss_dict, batch_pred).
    datas: per-language dataset collection consumed by get_lang_data.
  """
  # early stopping
  es = EarlyStopping(mode = 'max', patience = params.cldc_patience)
  # set optimizer
  optimizer = get_optimizer(params, m)
  # get initial parameters — snapshot used below as the L2 anchor when
  # zero-shot regularisation is enabled
  if params.zs_reg_alpha > 0:
    init_param_dict = {k: v.detach().clone() for k, v in m.named_parameters() if v.requires_grad}
  # training
  train_lang, train_data = get_lang_data(params, datas, training = True)
  # dev & test are in the same lang
  test_lang, test_data = get_lang_data(params, datas)
  # ceil(train_size / cldc_bs)
  n_batch = train_data.train_size // params.cldc_bs if train_data.train_size % params.cldc_bs == 0 else train_data.train_size // params.cldc_bs + 1
  # get the same n_batch for unlabelled data as well
  # batch size for unlabelled data
  rest_cldc_bs = train_data.rest_train_size // n_batch
  # per category: one shuffled index list per class
  data_idxs = [list(range(len(train_idx))) for train_idx in train_data.train_idxs]
  rest_data_idxs = list(range(len(train_data.rest_train_idxs)))
  # number of iterations
  cur_it = 0
  # write to tensorboard
  writer = SummaryWriter('./history/{}'.format(params.log_path)) if params.write_tfboard else None
  # best dev/test
  bdev = 0
  btest = 0
  # current dev/test
  cdev = 0
  ctest = 0
  dev_class_acc = {}
  test_class_acc = {}
  dev_cm = None
  test_cm = None
  # early stopping warm up flag, start es after train loss below some threshold
  es_flag = False
  # set io function: progress printer selected by the training mode
  out_semicldc = getattr(ios, 'out_semicldc_{}'.format(params.cldc_train_mode))
  for i in range(params.cldc_ep):
    for data_idx in data_idxs:
      shuffle(data_idx)
    shuffle(rest_data_idxs)
    for j in range(n_batch):
      train_idxs = []
      for k, data_idx in enumerate(data_idxs):
        # slice each category proportionally to its share of the batch;
        # the last batch takes whatever remains
        if j < n_batch - 1:
          train_idxs.append(data_idx[int(j * params.cldc_bs * train_data.train_prop[k]): int((j + 1) * params.cldc_bs * train_data.train_prop[k])])
          # NOTE(review): re-assigned identically on every category iteration;
          # harmless, but only the final assignment is used
          rest_train_idxs = rest_data_idxs[j * rest_cldc_bs: (j + 1) * rest_cldc_bs]
        elif j == n_batch - 1:
          train_idxs.append(data_idx[int(j * params.cldc_bs * train_data.train_prop[k]):])
          rest_train_idxs = rest_data_idxs[j * rest_cldc_bs:]
      # get batch data
      batch_train, batch_train_lens, batch_train_lb, batch_train_ohlb = get_batch(params, train_idxs, train_data.train_idxs, train_data.train_lens)
      batch_rest_train, batch_rest_train_lens, batch_rest_train_lb, batch_rest_train_ohlb = get_rest_batch(params, rest_train_idxs, train_data.rest_train_idxs, train_data.rest_train_lens, enumerate_discrete)
      optimizer.zero_grad()
      m.train()
      # warm-up flag consumed inside the model's forward during the first
      # cldc_warm_up_ep epochs
      if i + 1 <= params.cldc_warm_up_ep:
        m.warm_up = True
      else:
        m.warm_up = False
      loss_dict, batch_pred = m(train_lang, batch_train, batch_train_lens, batch_train_lb, batch_train_ohlb, batch_rest_train, batch_rest_train_lens, batch_rest_train_lb, batch_rest_train_ohlb)
      # regularization term: squared L2 distance to the initial parameters
      if params.zs_reg_alpha > 0:
        reg_loss = .0
        for k, v in m.named_parameters():
          if k in init_param_dict and v.requires_grad:
            reg_loss += torch.sum((v - init_param_dict[k]) ** 2)
        # NOTE(review): debug print left inside the training loop
        print(reg_loss.detach())
        reg_loss *= params.zs_reg_alpha / 2
        reg_loss.backward()
      batch_acc, batch_acc_cls = get_classification_report(params, batch_train_lb.data.cpu().numpy(), batch_pred.data.cpu().numpy())
      # arm early stopping once the training loss drops below the threshold
      if loss_dict['L_cldc_loss'] < params.cldc_lossth:
        es_flag = True
      # NOTE(review): the main loss backward is commented out, so
      # optimizer.step() below only applies gradients from reg_loss (when
      # zs_reg_alpha > 0) or any gradients produced inside m's forward —
      # confirm this is intended
      #loss_dict['total_loss'].backward()
      out_semicldc(i, j, n_batch, loss_dict, batch_acc, batch_acc_cls, bdev, btest, cdev, ctest, es.num_bad_epochs)
      #torch.nn.utils.clip_grad_norm_(filter(lambda p: p.grad is not None and p.requires_grad, m.parameters()), 5)
      '''
      # debug for gradient
      for p_name, p in m.named_parameters():
        if p.grad is not None and p.requires_grad:
          print(p_name, p.grad.data.norm(2).item())
      '''
      optimizer.step()
      cur_it += 1
      update_tensorboard(params, writer, loss_dict, batch_acc, cdev, ctest, dev_class_acc, test_class_acc, cur_it)
      if cur_it % params.CLDC_VAL_EVERY == 0:
        sys.stdout.write('\n')
        sys.stdout.flush()
        # validation
        cdev, dev_class_acc, dev_cm = test(params, m, test_data.dev_idxs, test_data.dev_lens, test_data.dev_size, test_data.dev_prop, test_lang, cm = True)
        ctest, test_class_acc, test_cm = test(params, m, test_data.test_idxs, test_data.test_lens, test_data.test_size, test_data.test_prop, test_lang, cm = True)
        print(dev_cm)
        print(test_cm)
        if es.step(cdev):
          print('\nEarly Stoped.')
          # vis
          #if params.cldc_visualize:
          #tsne2d(params, m)
          # vis
          return
        elif es.is_better(cdev, bdev):
          bdev = cdev
          btest = ctest
          #save_model(params, m)
        # reset bad epochs while early stopping is not yet armed
        if not es_flag:
          es.num_bad_epochs = 0
def train(params, m, datas):
  """Cross-lingual VAE training loop with optional adversarial training.

  Alternates batches across languages ('xl'/'xl-adv') or trains on a single
  language ('mo'). With adv_training, three Adam optimizers are maintained:
  one for the non-discriminator model, one for the discriminator, and one
  for encoder + embedding parameters. Validates every VAL_EVERY iterations
  and early-stops on dev NLL.

  Args:
    params: experiment configuration namespace.
    m: the model; forward returns (nll_batch, kld_batch, ls_dis, ls_enc).
    datas: one dataset per language.
  """
  es = EarlyStopping(min_delta=params.min_delta, patience=params.patience)
  # optimizer over everything except the discriminator
  ps = [p[1] for p in m.named_parameters() if 'discriminator' not in p[0]]
  print('Model parameter: {}'.format(sum(p.numel() for p in ps)))
  optimizer = optim.Adam(ps, lr=params.init_lr)
  if params.adv_training:
    # separate optimizer for the discriminator parameters
    dis_ps = [p[1] for p in m.named_parameters() if 'discriminator' in p[0]]
    dis_optimizer = optim.Adam(dis_ps, lr=params.init_lr)
    # and one for encoder + embedding parameters
    dis_enc_ps = [p[1] for p in m.named_parameters() if 'encoder' in p[0] or 'embedding' in p[0]]
    dis_enc_optimizer = optim.Adam(dis_enc_ps, lr=params.init_lr)
  # all training instances, split between 2 languages, right now the data are balanced
  n_batch = len(datas) * datas[0].train_size // params.bs if len(datas) * datas[0].train_size % params.bs == 0 else len(datas) * datas[0].train_size // params.bs + 1
  # per-language shuffled index lists
  data_idxs = {}
  for i, data in enumerate(datas):
    lang = data.vocab.lang
    data_idxs[lang] = list(range(data.train_size))
  # number of iterations
  cur_it = 0
  # write to tensorboard
  writer = SummaryWriter('./history/{}'.format(params.log_path)) if params.write_tfboard else None
  nll_dev = math.inf
  best_nll_dev = math.inf
  kld_dev = math.inf
  for i in range(params.ep):
    for lang in data_idxs:
      shuffle(data_idxs[lang])
    for j in range(n_batch):
      if params.task == 'xl' or params.task == 'xl-adv':
        # cross-lingual: alternate languages batch by batch
        lang_idx = j % len(datas)
        data = datas[lang_idx]
        lang = data.vocab.lang
        train_idxs = data_idxs[lang][j // len(datas) * params.bs:(j // len(datas) + 1) * params.bs]
      elif params.task == 'mo':
        # monolingual: always the first configured language
        lang = params.langs[0]
        lang_idx = params.lang_dict[lang]
        data = datas[lang_idx]
        train_idxs = data_idxs[lang][j * params.bs:(j + 1) * params.bs]
      padded_batch, batch_lens = get_batch(train_idxs, data, data.train_idxs, data.train_lens, params.cuda)
      optimizer.zero_grad()
      if params.adv_training:
        dis_optimizer.zero_grad()
        dis_enc_optimizer.zero_grad()
      m.train()
      nll_batch, kld_batch, ls_dis, ls_enc = m(lang, padded_batch, batch_lens)
      cur_it += 1
      loss_batch, alpha = calc_loss_batch(params, nll_batch, kld_batch, cur_it, n_batch)
      '''
      # add adversarial loss to the encoder
      if cur_it > params.adv_ep * n_batch:
        loss_batch += ls_enc
      '''
      if not params.adv_training:
        loss_batch.backward()
        optimizer.step()
      else:
        # NOTE(review): a single backward through the combined loss, then all
        # three optimizers step on whatever gradients landed on their
        # parameter groups — there is no separate discriminator/encoder
        # update phase; confirm this matches the intended adversarial scheme
        ls_dis = ls_dis.mean()
        ls_enc = ls_enc.mean()
        loss_batch = loss_batch + ls_dis + ls_enc
        loss_batch.backward()
        optimizer.step()
        dis_optimizer.step()
        dis_enc_optimizer.step()
      out_xling(i, j, n_batch, loss_batch, nll_batch, kld_batch, best_nll_dev, nll_dev, kld_dev, es.num_bad_epochs, ls_dis=ls_dis, ls_enc=ls_enc)
      update_tensorboard(writer, loss_batch, nll_batch, kld_batch, alpha, nll_dev, kld_dev, cur_it, ls_dis=ls_dis, ls_enc=ls_enc)
      if cur_it % params.VAL_EVERY == 0:
        sys.stdout.write('\n')
        sys.stdout.flush()
        # validation
        nll_dev, kld_dev = test(params, m, datas)
        if es.step(nll_dev):
          print('\nEarly Stoped.')
          return
        elif es.is_better(nll_dev, best_nll_dev):
          best_nll_dev = nll_dev
          # save model: per-language embeddings, then the full model
          for lang in params.langs:
            lang_idx = params.lang_dict[lang]
            m.save_embedding(params, datas[lang_idx])
          m.save_model(params, datas)
def train(params, m, datas):
    """Supervised CLDC training loop.

    Trains on the source language with per-category proportional batches,
    evaluates every CLDC_VAL_EVERY iterations, and early-stops (mode='max')
    on the dev score. Early stopping only starts counting once the training
    loss has dipped below params.cldc_lossth.

    Args:
      params: experiment configuration namespace.
      m: classifier; forward returns (loss, _, predictions).
      datas: per-language dataset collection consumed by get_lang_data.
    """
    es = EarlyStopping(mode='max', patience=params.cldc_patience)
    optimizer = get_optimizer(params, m)
    # source-language labelled training data
    train_lang, train_data = get_lang_data(params, datas, training=True)
    # target-language data; dev shares the language of test
    test_lang, test_data = get_lang_data(params, datas)
    # ceil(train_size / cldc_bs)
    n_batch = (train_data.train_size + params.cldc_bs - 1) // params.cldc_bs
    # one index list per category
    data_idxs = [list(range(len(cat_idx))) for cat_idx in train_data.train_idxs]
    cur_it = 0
    writer = SummaryWriter('./history/{}'.format(params.log_path)) if params.write_tfboard else None
    bdev, btest = 0, 0              # best dev score / its matching test score
    cdev, ctest = 0, 0              # most recent dev / test scores
    dev_class_acc, test_class_acc = {}, {}
    dev_cm, test_cm = None, None
    # early stopping is armed only after the loss threshold is crossed
    es_flag = False
    for epoch in range(params.cldc_ep):
        for cat_idx in data_idxs:
            shuffle(cat_idx)
        for b in range(n_batch):
            batch_cat_idxs = []
            is_last = b == n_batch - 1
            for cat, cat_idx in enumerate(data_idxs):
                # each category contributes its train_prop share of the batch;
                # the final batch absorbs the remainder
                lo = int(b * params.cldc_bs * train_data.train_prop[cat])
                if is_last:
                    batch_cat_idxs.append(cat_idx[lo:])
                else:
                    hi = int((b + 1) * params.cldc_bs * train_data.train_prop[cat])
                    batch_cat_idxs.append(cat_idx[lo:hi])
            batch_train, batch_train_lens, batch_train_lb = get_batch(params, batch_cat_idxs, train_data.train_idxs, train_data.train_lens)
            optimizer.zero_grad()
            m.train()
            cldc_loss_batch, _, batch_pred = m(train_lang, batch_train, batch_train_lens, batch_train_lb)
            batch_acc, batch_acc_cls = get_classification_report(params, batch_train_lb.data.cpu().numpy(), batch_pred.data.cpu().numpy())
            if cldc_loss_batch < params.cldc_lossth:
                es_flag = True
            cldc_loss_batch.backward()
            out_cldc(epoch, b, n_batch, cldc_loss_batch, batch_acc, batch_acc_cls, bdev, btest, cdev, ctest, es.num_bad_epochs)
            optimizer.step()
            cur_it += 1
            update_tensorboard(writer, cldc_loss_batch, batch_acc, cdev, ctest, dev_class_acc, test_class_acc, cur_it)
            if cur_it % params.CLDC_VAL_EVERY != 0:
                continue
            sys.stdout.write('\n')
            sys.stdout.flush()
            # NOTE(review): dev is evaluated on the TRAINING language's dev
            # split (the test-language variant existed but was disabled in the
            # original) — confirm this is the intended model-selection signal
            cdev, dev_class_acc, dev_cm = test(params, m, train_data.dev_idxs, train_data.dev_lens, train_data.dev_size, train_data.dev_prop, train_lang, cm=True)
            ctest, test_class_acc, test_cm = test(params, m, test_data.test_idxs, test_data.test_lens, test_data.test_size, test_data.test_prop, test_lang, cm=True)
            print(dev_cm)
            print(test_cm)
            if es.step(cdev):
                print('\nEarly Stoped.')
                return
            if es.is_better(cdev, bdev):
                bdev = cdev
                btest = ctest
            # while early stopping is not yet armed, keep resetting the counter
            if not es_flag:
                es.num_bad_epochs = 0
def main(params, m, data):
  """Classifier training loop with optional self-training.

  After params.semi_warm_up epochs (or immediately if params.self_train is
  already set), switches permanently to self-training: pseudo-labelled data
  is merged in via self_train_merge_data each epoch and the batch schedule
  is rebuilt from the merged set. Validates on dev/test and early-stops
  (mode='max') on the dev score.

  Args:
    params: experiment configuration namespace; NOTE: params.self_train is
      mutated (set to True) once self-training begins.
    m: classifier; called as m(batch, labels=...) returning (loss, logits).
    data: dataset with train/dev/test splits (and self_train_* after merge).
  """
  # early stopping
  es = EarlyStopping(mode='max', patience=params.patience)
  # set optimizer
  optimizer = get_optimizer(params, m)
  # ceil(train_size / bs)
  n_batch = data.train_size // params.bs if data.train_size % params.bs == 0 else data.train_size // params.bs + 1
  # per category: one shuffled index list per class
  data_idxs = [list(range(len(train_idx))) for train_idx in data.train_idxs]
  # number of iterations
  cur_it = 0
  # best xx
  bdev = 0
  btest = 0
  # current xx
  cdev = 0
  ctest = 0
  dev_class_acc = {}
  test_class_acc = {}
  dev_cm = None
  test_cm = None
  # early stopping warm up flag, start es after some iters
  es_flag = False
  for i in range(params.ep):
    # self-training: permanently enabled once the warm-up epoch is reached
    if params.self_train or i >= params.semi_warm_up:
      params.self_train = True
      first_update = (i == params.semi_warm_up)
      # only for zero-shot: reset early stopping / best scores when
      # pseudo-labels are merged for the first time
      if first_update:
        es.num_bad_epochs = 0
        es.best = 0
        bdev = 0
        btest = 0
      data = self_train_merge_data(params, m, es, data, first=first_update)
      # rebuild batch schedule over the merged (train + pseudo-label) set
      n_batch = data.self_train_size // params.bs if data.self_train_size % params.bs == 0 else data.self_train_size // params.bs + 1
      # per category
      data_idxs = [list(range(len(train_idx))) for train_idx in data.self_train_idxs]
    for data_idx in data_idxs:
      shuffle(data_idx)
    for j in range(n_batch):
      train_idxs = []
      for k, data_idx in enumerate(data_idxs):
        # category proportions come from the active (self-)training set
        if params.self_train:
          train_prop = data.self_train_prop
        else:
          train_prop = data.train_prop
        # proportional slice per category; last batch takes the remainder
        if j < n_batch - 1:
          train_idxs.append(data_idx[int(j * params.bs * train_prop[k]):int((j + 1) * params.bs * train_prop[k])])
        elif j == n_batch - 1:
          train_idxs.append(data_idx[int(j * params.bs * train_prop[k]):])
      if params.self_train:
        batch_train, _, batch_train_lb = get_batch(params, train_idxs, data.self_train_idxs, data.self_train_lens)
      else:
        batch_train, _, batch_train_lb = get_batch(params, train_idxs, data.train_idxs, data.train_lens)
      optimizer.zero_grad()
      m.train()
      loss_batch, logits = m(batch_train, labels=batch_train_lb)
      batch_pred = torch.argmax(logits, dim=1)
      batch_acc, batch_acc_cls = get_classification_report(params, batch_train_lb.data.cpu().numpy(), batch_pred.data.cpu().numpy())
      # arm early stopping once the training loss drops below the threshold
      if loss_batch < params.lossth:
        es_flag = True
      loss_batch.backward()
      out_cldc(i, j, n_batch, loss_batch, batch_acc, batch_acc_cls, bdev, btest, cdev, ctest, es.num_bad_epochs)
      optimizer.step()
      cur_it += 1
    # NOTE(review): unlike the sibling loops there is no VAL_EVERY guard, so
    # validation is reconstructed here at once-per-epoch (after the batch
    # loop) — the original file's collapsed formatting makes the indentation
    # ambiguous; confirm against the project history
    sys.stdout.write('\n')
    sys.stdout.flush()
    # validation
    cdev, dev_class_acc, dev_cm = test(params, m, data.dev_idxs, data.dev_lens, data.dev_size, data.dev_prop, cm=True)
    ctest, test_class_acc, test_cm = test(params, m, data.test_idxs, data.test_lens, data.test_size, data.test_prop, cm=True)
    print(dev_cm)
    print(test_cm)
    if es.step(cdev):
      print('\nEarly Stoped.')
      return
    elif es.is_better(cdev, bdev):
      bdev = cdev
      btest = ctest
    # reset bad epochs while early stopping is not yet armed
    if not es_flag:
      es.num_bad_epochs = 0