import sys
from random import shuffle

import torch
from torch.utils.tensorboard import SummaryWriter

# Project-local helpers (EarlyStopping, get_optimizer, get_lang_data, get_batch,
# get_rest_batch, get_classification_report, out_cldc, ios, update_tensorboard,
# test, save_model, self_train_merge_data, enumerate_discrete) are assumed to be
# imported from elsewhere in this repo.


def train(params, m, datas):
    # early stopping
    es = EarlyStopping(mode='max', patience=params.cldc_patience)
    # set optimizer
    optimizer = get_optimizer(params, m)
    # train on one language; dev/test on another
    # get training data
    train_lang, train_data = get_lang_data(params, datas, training=True)
    # get dev and test; dev is in the same language as test
    test_lang, test_data = get_lang_data(params, datas)
    # number of batches (ceiling division)
    n_batch = (train_data.train_size // params.cldc_bs
               if train_data.train_size % params.cldc_bs == 0
               else train_data.train_size // params.cldc_bs + 1)
    # per-category index lists
    data_idxs = [list(range(len(train_idx))) for train_idx in train_data.train_idxs]
    # number of iterations
    cur_it = 0
    # write to tensorboard
    writer = SummaryWriter('./history/{}'.format(params.log_path)) if params.write_tfboard else None
    # best dev/test accuracy
    bdev = 0
    btest = 0
    # current dev/test accuracy
    cdev = 0
    ctest = 0
    dev_class_acc = {}
    test_class_acc = {}
    dev_cm = None
    test_cm = None
    # early-stopping warm-up flag: start early stopping only after some iterations
    es_flag = False
    for i in range(params.cldc_ep):
        for data_idx in data_idxs:
            shuffle(data_idx)
        for j in range(n_batch):
            # slice a label-proportional chunk out of each category's index list
            train_idxs = []
            for k, data_idx in enumerate(data_idxs):
                if j < n_batch - 1:
                    train_idxs.append(data_idx[int(j * params.cldc_bs * train_data.train_prop[k]):
                                               int((j + 1) * params.cldc_bs * train_data.train_prop[k])])
                elif j == n_batch - 1:
                    train_idxs.append(data_idx[int(j * params.cldc_bs * train_data.train_prop[k]):])
            batch_train, batch_train_lens, batch_train_lb = get_batch(
                params, train_idxs, train_data.train_idxs, train_data.train_lens)
            optimizer.zero_grad()
            m.train()
            cldc_loss_batch, _, batch_pred = m(train_lang, batch_train, batch_train_lens, batch_train_lb)
            batch_acc, batch_acc_cls = get_classification_report(
                params, batch_train_lb.data.cpu().numpy(), batch_pred.data.cpu().numpy())
            # enable early stopping once the training loss drops below the threshold
            if cldc_loss_batch < params.cldc_lossth:
                es_flag = True
            cldc_loss_batch.backward()
            out_cldc(i, j, n_batch, cldc_loss_batch, batch_acc, batch_acc_cls,
                     bdev, btest, cdev, ctest, es.num_bad_epochs)
            optimizer.step()
            cur_it += 1
            update_tensorboard(writer, cldc_loss_batch, batch_acc, cdev, ctest,
                               dev_class_acc, test_class_acc, cur_it)
            if cur_it % params.CLDC_VAL_EVERY == 0:
                sys.stdout.write('\n')
                sys.stdout.flush()
                # validation: dev in the training language, test in the target language
                #cdev, dev_class_acc, dev_cm = test(params, m, test_data.dev_idxs, test_data.dev_lens, test_data.dev_size, test_data.dev_prop, test_lang, cm=True)
                cdev, dev_class_acc, dev_cm = test(params, m, train_data.dev_idxs, train_data.dev_lens,
                                                   train_data.dev_size, train_data.dev_prop, train_lang, cm=True)
                ctest, test_class_acc, test_cm = test(params, m, test_data.test_idxs, test_data.test_lens,
                                                      test_data.test_size, test_data.test_prop, test_lang, cm=True)
                print(dev_cm)
                print(test_cm)
                if es.step(cdev):
                    print('\nEarly Stopped.')
                    return
                elif es.is_better(cdev, bdev):
                    bdev = cdev
                    btest = ctest
                    #save_model(params, m)
                # reset bad epochs until the warm-up threshold has been reached
                if not es_flag:
                    es.num_bad_epochs = 0
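# EarlyStopping is a project-local helper whose implementation is not shown in
# this file. A minimal sketch consistent with how it is used here (mode='max',
# patience, .step(), .is_better(), and externally reset .num_bad_epochs / .best)
# might look like the following; the class name is hypothetical and the real
# implementation may differ.
class EarlyStoppingSketch:

    def __init__(self, mode='max', patience=10):
        self.mode = mode
        self.patience = patience
        self.best = 0
        self.num_bad_epochs = 0

    def is_better(self, metric, best):
        # 'max' mode: a larger metric (e.g. dev accuracy) is better
        return metric > best if self.mode == 'max' else metric < best

    def step(self, metric):
        # track the best metric so far; return True when patience is exhausted
        if self.is_better(metric, self.best):
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs > self.patience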
def train(params, m, datas):
    # early stopping
    es = EarlyStopping(mode='max', patience=params.cldc_patience)
    # set optimizer
    optimizer = get_optimizer(params, m)
    # snapshot the initial parameters for regularization towards the starting point
    if params.zs_reg_alpha > 0:
        init_param_dict = {k: v.detach().clone() for k, v in m.named_parameters() if v.requires_grad}
    # training
    train_lang, train_data = get_lang_data(params, datas, training=True)
    # dev & test are in the same language
    test_lang, test_data = get_lang_data(params, datas)
    # number of batches (ceiling division)
    n_batch = (train_data.train_size // params.cldc_bs
               if train_data.train_size % params.cldc_bs == 0
               else train_data.train_size // params.cldc_bs + 1)
    # use the same n_batch for the unlabelled data as well:
    # batch size for unlabelled data
    rest_cldc_bs = train_data.rest_train_size // n_batch
    # per-category index lists
    data_idxs = [list(range(len(train_idx))) for train_idx in train_data.train_idxs]
    rest_data_idxs = list(range(len(train_data.rest_train_idxs)))
    # number of iterations
    cur_it = 0
    # write to tensorboard
    writer = SummaryWriter('./history/{}'.format(params.log_path)) if params.write_tfboard else None
    # best dev/test accuracy
    bdev = 0
    btest = 0
    # current dev/test accuracy
    cdev = 0
    ctest = 0
    dev_class_acc = {}
    test_class_acc = {}
    dev_cm = None
    test_cm = None
    # early-stopping warm-up flag: start early stopping only after the train loss
    # drops below a threshold
    es_flag = False
    # pick the io function matching the training mode
    out_semicldc = getattr(ios, 'out_semicldc_{}'.format(params.cldc_train_mode))
    for i in range(params.cldc_ep):
        for data_idx in data_idxs:
            shuffle(data_idx)
        shuffle(rest_data_idxs)
        for j in range(n_batch):
            # unlabelled indices for this step (the last batch takes the remainder)
            if j < n_batch - 1:
                rest_train_idxs = rest_data_idxs[j * rest_cldc_bs: (j + 1) * rest_cldc_bs]
            else:
                rest_train_idxs = rest_data_idxs[j * rest_cldc_bs:]
            # slice a label-proportional chunk out of each category's index list
            train_idxs = []
            for k, data_idx in enumerate(data_idxs):
                if j < n_batch - 1:
                    train_idxs.append(data_idx[int(j * params.cldc_bs * train_data.train_prop[k]):
                                               int((j + 1) * params.cldc_bs * train_data.train_prop[k])])
                elif j == n_batch - 1:
                    train_idxs.append(data_idx[int(j * params.cldc_bs * train_data.train_prop[k]):])
            # get batch data, labelled and unlabelled
            batch_train, batch_train_lens, batch_train_lb, batch_train_ohlb = get_batch(
                params, train_idxs, train_data.train_idxs, train_data.train_lens)
            batch_rest_train, batch_rest_train_lens, batch_rest_train_lb, batch_rest_train_ohlb = get_rest_batch(
                params, rest_train_idxs, train_data.rest_train_idxs, train_data.rest_train_lens, enumerate_discrete)
            optimizer.zero_grad()
            m.train()
            # warm up for the first cldc_warm_up_ep epochs
            m.warm_up = (i + 1 <= params.cldc_warm_up_ep)
            loss_dict, batch_pred = m(train_lang,
                                      batch_train, batch_train_lens, batch_train_lb, batch_train_ohlb,
                                      batch_rest_train, batch_rest_train_lens, batch_rest_train_lb,
                                      batch_rest_train_ohlb)
            # regularization term: penalize L2 drift from the initial parameters
            if params.zs_reg_alpha > 0:
                reg_loss = 0.0
                for k, v in m.named_parameters():
                    if k in init_param_dict and v.requires_grad:
                        reg_loss += torch.sum((v - init_param_dict[k]) ** 2)
                # print(reg_loss.detach())  # debug
                reg_loss *= params.zs_reg_alpha / 2
                reg_loss.backward()
            batch_acc, batch_acc_cls = get_classification_report(
                params, batch_train_lb.data.cpu().numpy(), batch_pred.data.cpu().numpy())
            # enable early stopping once the labelled classification loss drops below the threshold
            if loss_dict['L_cldc_loss'] < params.cldc_lossth:
                es_flag = True
            loss_dict['total_loss'].backward()
            out_semicldc(i, j, n_batch, loss_dict, batch_acc, batch_acc_cls,
                         bdev, btest, cdev, ctest, es.num_bad_epochs)
            #torch.nn.utils.clip_grad_norm_(filter(lambda p: p.grad is not None and p.requires_grad, m.parameters()), 5)
            # debug for gradient
            #for p_name, p in m.named_parameters():
            #    if p.grad is not None and p.requires_grad:
            #        print(p_name, p.grad.data.norm(2).item())
            optimizer.step()
            cur_it += 1
            update_tensorboard(params, writer, loss_dict, batch_acc, cdev, ctest,
                               dev_class_acc, test_class_acc, cur_it)
            if cur_it % params.CLDC_VAL_EVERY == 0:
                sys.stdout.write('\n')
                sys.stdout.flush()
                # validation
                cdev, dev_class_acc, dev_cm = test(params, m, test_data.dev_idxs, test_data.dev_lens,
                                                   test_data.dev_size, test_data.dev_prop, test_lang, cm=True)
                ctest, test_class_acc, test_cm = test(params, m, test_data.test_idxs, test_data.test_lens,
                                                      test_data.test_size, test_data.test_prop, test_lang, cm=True)
                print(dev_cm)
                print(test_cm)
                if es.step(cdev):
                    print('\nEarly Stopped.')
                    #if params.cldc_visualize:
                    #    tsne2d(params, m)
                    return
                elif es.is_better(cdev, bdev):
                    bdev = cdev
                    btest = ctest
                    #save_model(params, m)
                # reset bad epochs until the warm-up threshold has been reached
                if not es_flag:
                    es.num_bad_epochs = 0
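# get_rest_batch receives enumerate_discrete so that the model can marginalize
# its classifier over every possible label for the unlabelled examples, as in
# semi-supervised VAEs (Kingma et al., 2014). The helper itself is not shown in
# this file; a plausible sketch, assuming it pairs each batch element with every
# class as a one-hot vector, is given below (the name and exact signature are
# hypothetical; relies on the torch import above).
def enumerate_discrete_sketch(x, y_dim):
    # class ids 0..y_dim-1, each repeated once per batch element
    # -> shape (y_dim * batch_size,)
    ys = torch.arange(y_dim, device=x.device).repeat_interleave(x.size(0))
    # one-hot encode -> shape (y_dim * batch_size, y_dim)
    return torch.nn.functional.one_hot(ys, y_dim).float()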
def main(params, m, data):
    # early stopping
    es = EarlyStopping(mode='max', patience=params.patience)
    # set optimizer
    optimizer = get_optimizer(params, m)
    # number of batches (ceiling division)
    n_batch = (data.train_size // params.bs
               if data.train_size % params.bs == 0
               else data.train_size // params.bs + 1)
    # per-category index lists
    data_idxs = [list(range(len(train_idx))) for train_idx in data.train_idxs]
    # number of iterations
    cur_it = 0
    # best dev/test accuracy
    bdev = 0
    btest = 0
    # current dev/test accuracy
    cdev = 0
    ctest = 0
    dev_class_acc = {}
    test_class_acc = {}
    dev_cm = None
    test_cm = None
    # early-stopping warm-up flag: start early stopping only after some iterations
    es_flag = False
    for i in range(params.ep):
        # self-training: after the warm-up epochs, merge pseudo-labelled data into the training set
        if params.self_train or i >= params.semi_warm_up:
            params.self_train = True
            first_update = (i == params.semi_warm_up)
            # only for zero-shot: reset early stopping on the first pseudo-label update
            if first_update:
                es.num_bad_epochs = 0
                es.best = 0
                bdev = 0
                btest = 0
            data = self_train_merge_data(params, m, es, data, first=first_update)
            n_batch = (data.self_train_size // params.bs
                       if data.self_train_size % params.bs == 0
                       else data.self_train_size // params.bs + 1)
            # per-category index lists
            data_idxs = [list(range(len(train_idx))) for train_idx in data.self_train_idxs]
        for data_idx in data_idxs:
            shuffle(data_idx)
        for j in range(n_batch):
            # slice a label-proportional chunk out of each category's index list
            train_idxs = []
            for k, data_idx in enumerate(data_idxs):
                train_prop = data.self_train_prop if params.self_train else data.train_prop
                if j < n_batch - 1:
                    train_idxs.append(data_idx[int(j * params.bs * train_prop[k]):
                                               int((j + 1) * params.bs * train_prop[k])])
                elif j == n_batch - 1:
                    train_idxs.append(data_idx[int(j * params.bs * train_prop[k]):])
            if params.self_train:
                batch_train, _, batch_train_lb = get_batch(
                    params, train_idxs, data.self_train_idxs, data.self_train_lens)
            else:
                batch_train, _, batch_train_lb = get_batch(
                    params, train_idxs, data.train_idxs, data.train_lens)
            optimizer.zero_grad()
            m.train()
            loss_batch, logits = m(batch_train, labels=batch_train_lb)
            batch_pred = torch.argmax(logits, dim=1)
            batch_acc, batch_acc_cls = get_classification_report(
                params, batch_train_lb.data.cpu().numpy(), batch_pred.data.cpu().numpy())
            # enable early stopping once the training loss drops below the threshold
            if loss_batch < params.lossth:
                es_flag = True
            loss_batch.backward()
            out_cldc(i, j, n_batch, loss_batch, batch_acc, batch_acc_cls,
                     bdev, btest, cdev, ctest, es.num_bad_epochs)
            optimizer.step()
            cur_it += 1
        sys.stdout.write('\n')
        sys.stdout.flush()
        # validation at the end of each epoch
        cdev, dev_class_acc, dev_cm = test(params, m, data.dev_idxs, data.dev_lens,
                                           data.dev_size, data.dev_prop, cm=True)
        ctest, test_class_acc, test_cm = test(params, m, data.test_idxs, data.test_lens,
                                              data.test_size, data.test_prop, cm=True)
        print(dev_cm)
        print(test_cm)
        if es.step(cdev):
            print('\nEarly Stopped.')
            return
        elif es.is_better(cdev, bdev):
            bdev = cdev
            btest = ctest
        # reset bad epochs until the warm-up threshold has been reached
        if not es_flag:
            es.num_bad_epochs = 0
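# The n_batch expression above appears three times in this file; it is just an
# integer ceiling division. A behavior-identical helper (hypothetical name),
# shown for clarity:
def num_batches(n_examples, batch_size):
    # e.g. num_batches(10, 4) == 3 and num_batches(8, 4) == 2,
    # matching the if/else expressions used in train() and main()
    return (n_examples + batch_size - 1) // batch_size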