def train(
    model: GP,
    train_x: torch.Tensor,
    train_y: torch.Tensor,
    num_iters: int,
    lr: float = 0.1,
    show_progress: bool = True,
):
    """Trains the provided model by maximising the marginal likelihood."""
    model.train()
    optimizer = AdamW(model.parameters(), lr=lr)
    mll = gp.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
    loss = 0
    iterator = (
        tqdm(range(num_iters), desc="Epoch") if show_progress else range(num_iters)
    )
    for _ in iterator:
        optimizer.zero_grad()
        output = model(train_x)
        loss = -mll(output, train_y)
        loss.backward()
        optimizer.step()
        if show_progress:
            iterator.set_postfix(loss=loss.item())
    return loss.detach().cpu().item()
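# A minimal usage sketch for the marginal-likelihood loop above, assuming `GP` is a
# gpytorch ExactGP subclass and `gp` is an alias for the gpytorch package (as the
# `gp.mlls` reference suggests); the kernel choice and toy data are assumptions.
import torch
import gpytorch as gp
from torch.optim import AdamW
from tqdm import tqdm


class GP(gp.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gp.means.ConstantMean()
        self.covar_module = gp.kernels.ScaleKernel(gp.kernels.RBFKernel())

    def forward(self, x):
        # prior mean and covariance for the given inputs
        return gp.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x))


train_x = torch.linspace(0, 1, 50)
train_y = torch.sin(train_x * 6.0) + 0.1 * torch.randn(50)
model = GP(train_x, train_y, gp.likelihoods.GaussianLikelihood())
final_loss = train(model, train_x, train_y, num_iters=100)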
def train(path: str, epochs: int = 3) -> None: LR = 1e-3 DECAY = 1e-4 dataset = MNISTDataset() loader = MNISTLoader() model = superdupermodel().cuda() criterion = nn.CrossEntropyLoss().cuda() optimizer = AdamW(model.parameters(), lr=LR, weight_decay=DECAY) for epoch in tqdm(range(epochs), desc="Epoch", position=0): total_loss = 0.0 total_acc = 0.0 model = model.train() with tqdm(loader.train, "Train", position=1) as pbar: for inputs, labels in pbar: inputs, labels = inputs.cuda(), labels.cuda() optimizer.zero_grad() out = model(inputs) loss = criterion(out, labels) acc = (torch.argmax(torch.softmax(out, 1), 1) == labels).sum() loss.backward() optimizer.step() total_loss += loss.item() / len(loader.train) total_acc += acc.item() / len(dataset.train) pbar.set_postfix( loss=f"{total_loss:.2e}", acc=f"{total_acc * 100:.2f}%", ) with torch.no_grad(): total_loss = 0.0 total_acc = 0.0 model = model.eval() with tqdm(loader.test, "Valid", position=2) as pbar: for inputs, labels in pbar: inputs, labels = inputs.cuda(), labels.cuda() out = model(inputs) loss = criterion(out, labels) acc = (torch.argmax(torch.softmax(out, 1), 1) == labels).sum() total_loss += loss.item() / len(loader.test) total_acc += acc.item() / len(dataset.test) pbar.set_postfix( loss=f"{total_loss:.2e}", acc=f"{total_acc * 100:.2f}%", ) torch.save(model.state_dict(), f"{path}.pth")
def decompose(self, conv, pw, dw, lr=0.001, steps=600):
    """
    GEP decompose standard convolution kernel
    :param conv: standard convolution kernel
    :param pw: decomposed pointwise convolution kernel
    :param dw: decomposed depthwise convolution kernel
    :param lr: learning rate
    :param steps: training steps for decomposing
    """
    conv.requires_grad = False
    pw.requires_grad = True
    dw.requires_grad = True
    criterion = nn.MSELoss()
    optimizer = AdamW([pw, dw], lr=lr)
    st = time.time()
    for s in range(steps):
        # decay the learning rate at the scheduled steps
        if s in (400, 700):
            lr = lr / 10
            optimizer = AdamW([pw, dw], lr=lr)
        optimizer.zero_grad()
        kernel_pred = pw.cuda() * dw.cuda()
        loss = criterion(kernel_pred, conv.cuda())
        loss.backward()
        optimizer.step()
        if s % 100 == 99:
            print('loss = %f, time = %d' % (loss.item(), (time.time() - st)))
            st = time.time()
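# The same rank-1 factorisation on raw tensors, as a self-contained sketch: a standard
# kernel (out, in, k, k) is approximated by the broadcasted product of a pointwise
# factor (out, in, 1, 1) and a depthwise factor (out, 1, k, k). Shapes are assumptions.
import torch
import torch.nn as nn
from torch.optim import AdamW

conv = torch.randn(64, 32, 3, 3)                      # target kernel, kept fixed
pw = torch.randn(64, 32, 1, 1, requires_grad=True)    # pointwise factor
dw = torch.randn(64, 1, 3, 3, requires_grad=True)     # depthwise factor

criterion = nn.MSELoss()
optimizer = AdamW([pw, dw], lr=1e-3)
for step in range(600):
    optimizer.zero_grad()
    loss = criterion(pw * dw, conv)   # broadcasting yields a (64, 32, 3, 3) prediction
    loss.backward()
    optimizer.step()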
def train(path: str, epochs: int = 3) -> None:
    dataset = MNISTDataset()
    loader = MNISTLoader()
    model = LeNet5(n_classes=10).cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = AdamW(model.parameters(), lr=1e-3)
    for epoch in tqdm(range(epochs), desc="Epoch"):
        model.train()
        total_loss, total_acc = 0, 0
        pbar = tqdm(loader.train, desc="Train")
        for img, label in pbar:
            img, label = img.cuda(), label.cuda()
            optimizer.zero_grad()
            preds = model(img)
            loss = criterion(preds, label)
            acc = (torch.argmax(torch.softmax(preds, dim=1), dim=1) == label).sum()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() / len(loader.train)
            total_acc += acc.item() / len(dataset.train)
            pbar.set_postfix(
                loss=f"{total_loss:.2e}",
                acc=f"{total_acc * 100:.2f}%",
            )
        model.eval()
        total_loss, total_acc = 0, 0
        pbar = tqdm(loader.test, desc="Test")
        with torch.no_grad():  # no gradients needed during evaluation
            for img, label in pbar:
                img, label = img.cuda(), label.cuda()
                preds = model(img)
                loss = criterion(preds, label)
                acc = (torch.argmax(torch.softmax(preds, dim=1), dim=1) == label).sum()
                total_loss += loss.item() / len(loader.test)
                total_acc += acc.item() / len(dataset.test)
                pbar.set_postfix(
                    loss=f"{total_loss:.2e}",
                    acc=f"{total_acc * 100:.2f}%",
                )
    torch.save(model.state_dict(), f"{path}.pth")
class Scheduler: def __init__(self, model, args): super(Scheduler, self).__init__() self.loss = Loss() self.optimizer = AdamW(model.parameters(), lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay) self.warm_up = args.warm_up self.curr_step = 0 self.init_lr = args.lr self.curr_loss = None def __call__(self, out_mask_lm, out_nsp, target): mask_pos, mask_label, nsp_label = target mask_pos = mask_pos.unsqueeze(-1).expand(mask_pos.size(0), mask_pos.size(1), out_mask_lm.size(-1)) out_mask_lm = torch.gather(out_mask_lm, 1, mask_pos) nsp_label = nsp_label.long() # calculate loss loss_nsp = self.loss(out_nsp, nsp_label) loss_mask_lm = self.loss(out_mask_lm.transpose(1, 2), mask_label) self.curr_loss = loss_mask_lm + loss_nsp # calculate acc pred_mask_lm = out_mask_lm[:, :, :].max(dim=-1)[1] pred_nsp_lm = out_nsp[:, :].max(dim=-1)[1] mask_lm_acc = pred_mask_lm.eq(mask_label).sum() / len( pred_mask_lm.view(-1)) nsp_acc = pred_nsp_lm.eq(nsp_label).sum() / len(pred_nsp_lm.view(-1)) return self.curr_loss.data, mask_lm_acc, nsp_acc def step(self, epoch): self.curr_loss.backward() self._update(epoch) self.optimizer.step() self.optimizer.zero_grad() def _update(self, epoch): self.curr_step = epoch lr = self.init_lr * self._lr_scale() for param_group in self.optimizer.param_groups: param_group['lr'] = lr def _lr_scale(self): # if self.curr_step < self.warm_up: # return 1 # else: # return 2 ** -((self.curr_step - self.warm_up) // 35) return 1
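# The commented-out schedule in _lr_scale suggests a flat warm-up followed by a halving
# every 35 epochs. A standalone sketch of that scale as a pure function (the warm-up
# length and base lr below are illustrative assumptions):
def lr_scale(step, warm_up, half_life=35):
    if step < warm_up:
        return 1.0
    return 2.0 ** -((step - warm_up) // half_life)


base_lr = 1e-4
for epoch in (0, 5, 10, 45, 80):
    print(epoch, base_lr * lr_scale(epoch, warm_up=10))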
def coteaching(train_xs, train_ys, test_xs, test_ys):
    train_xs = np.moveaxis(train_xs, 3, 1)
    test_xs = np.moveaxis(test_xs, 3, 1)
    batch_size = 1024
    train_loader = DataLoader(TensorDataset(FloatTensor(train_xs), LongTensor(train_ys)),
                              batch_size=batch_size)
    n_epoch, forget_rate = 100, 0.2
    rate_schedule = np.ones(n_epoch) * forget_rate
    rate_schedule[0] = 0.0
    device = 'cuda:0'
    model1 = models.resnet18(num_classes=2).to(device)
    optim1 = AdamW(model1.parameters())
    model2 = models.resnet18(num_classes=2).to(device)
    optim2 = AdamW(model2.parameters())
    for epoch in range(1, n_epoch):
        iters, acc1, acc2 = 0, 0, 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            iters += 1
            logits1 = model1(images)
            acc1 += accuracy(logits1, labels, batch_size)
            logits2 = model2(images)
            acc2 += accuracy(logits2, labels, batch_size)
            # each network backpropagates only on the small-loss samples
            # selected by its peer (see loss_coteaching)
            loss_1, loss_2 = loss_coteaching(logits1, logits2, labels,
                                             rate_schedule[epoch])
            optim1.zero_grad()
            loss_1.backward()
            optim1.step()
            optim2.zero_grad()
            loss_2.backward()
            optim2.step()
        printr('Coteaching: Epoch {}: acc1 {:.4f}, acc2 {:.4f}'.format(
            epoch, acc1 / iters, acc2 / iters))
    printr('')
    test_loader = DataLoader(TensorDataset(FloatTensor(test_xs), LongTensor(test_ys)),
                             batch_size=1024)

    def _eval(model):
        total, correct = 0, 0
        model.eval()
        with torch.no_grad():
            for images, labels in test_loader:
                _, preds = torch.max(F.softmax(model(images.cuda()), dim=1), 1)
                total += len(labels)
                correct += int((preds.cpu() == labels).sum())
        return correct / total

    return (_eval(model1) + _eval(model2)) / 2
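# `loss_coteaching` is assumed above but not defined in the snippet. A minimal sketch of
# the usual co-teaching rule: each network keeps the (1 - forget_rate) fraction of
# small-loss samples selected by its peer. The original implementation may differ in detail.
import torch
import torch.nn.functional as F


def loss_coteaching(logits1, logits2, labels, forget_rate):
    num_remember = int((1.0 - forget_rate) * len(labels))

    loss1 = F.cross_entropy(logits1, labels, reduction='none')
    idx1 = torch.argsort(loss1)[:num_remember]   # samples model 1 finds clean

    loss2 = F.cross_entropy(logits2, labels, reduction='none')
    idx2 = torch.argsort(loss2)[:num_remember]   # samples model 2 finds clean

    # cross-update: each model learns from the other's selection
    loss_1_update = F.cross_entropy(logits1[idx2], labels[idx2])
    loss_2_update = F.cross_entropy(logits2[idx1], labels[idx1])
    return loss_1_update, loss_2_update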
def train( path: str, save_all: bool, epochs: int = 3, ) -> None: dataset = MNISTDataset() loader = MNISTLoader() # model = Model() model = LeNet5(10) criterion = nn.CrossEntropyLoss() optim = AdamW(model.parameters(), lr=1e-3) best_acc = 0.0 acc = 0.0 for epoch in tqdm(range(epochs), desc="Epoch"): model.train() with tqdm(loader.trainloader, desc="Train") as pbar: total_loss = 0.0 acc = 0.0 for img, label in pbar: optim.zero_grad() output = model(img) loss = criterion(output, label) loss.backward() optim.step() total_loss += loss.item() / len(loader.trainloader) acc += (torch.argmax(output, dim=1) == label).sum().item() / len(dataset.trainset) pbar.set_postfix(loss=total_loss, acc=f"{acc * 100:.2f}%") model.eval() with tqdm(loader.validloader, desc="Valid") as pbar: total_loss = 0.0 acc = 0.0 with torch.no_grad(): for img, label in pbar: output = model(img) loss = criterion(output, label) total_loss += loss.item() / len(loader.validloader) acc += (torch.argmax(output, dim=1) == label).sum().item() / len(dataset.validset) pbar.set_postfix(loss=total_loss, acc=f"{acc * 100:.2f}%") if acc > best_acc: torch.save(model.state_dict(), f"{path}/best.pt") tqdm.write("saved best") best_acc = acc if save_all: torch.save(model.state_dict(), f"{path}/mnist_{epoch+1:02d}.pt")
def test_memorize_minibatch(self): for db_name in self.db_names: db_info = get_db_info(db_name) train_data, val_data, _ = get_train_val_test_datasets( dataset_name=db_name, train_test_split='use_full_train', encoders=dict(CATEGORICAL='CategoricalOrdinalEnc', SCALAR='ScalarRobustScalerEnc', DATETIME='DatetimeScalarEnc', LATLONG='LatLongScalarEnc', TEXT='TextSummaryScalarEnc'), ) train_loader = get_dataloader( dataset=train_data, batch_size=256, sampler_class_name='SequentialSampler', num_workers=0, max_nodes_per_graph=False) writer = DummyWriter() model = GCN(writer, db_info=db_info, hidden_dim=256, n_init_layers=3, activation_class_name='SELU', activation_class_kwargs={}, loss_class_kwargs={}, loss_class_name='CrossEntropyLoss', p_dropout=0.0, drop_whole_embeddings=True, n_layers=3, readout_class_name='AvgPooling', readout_kwargs={}) if torch.cuda.is_available(): model.cuda() model.device = torch.device('cuda:0') else: model.device = torch.device('cpu') model.train() optimizer = AdamW(model.parameters(), lr=0.001, weight_decay=0.0) bdgl, features, label = next(iter(train_loader)) recursive_to((bdgl, features, label), model.device) for _ in tqdm(range(200)): optimizer.zero_grad() output = model(bdgl, features) loss = model.loss_fxn(output, label) if loss < 1e-4: break loss.backward() optimizer.step() else: tqdm.write(f'Loss: {loss}') self.fail("Didn't memorize minibatch")
def decompose_rank(self, kernel, lr=5e-3, steps=600):
    """
    GEP decompose standard convolution kernel with different rank
    :param kernel: standard convolution kernel
    :param lr: learning rate
    :param steps: training steps for decomposing
    """
    kernel.requires_grad = False
    params = [self.pw0.weight, self.dw0.weight]
    for i in range(self.rank):
        getattr(self, 'pw' + str(i)).weight.requires_grad = True
        getattr(self, 'dw' + str(i)).weight.requires_grad = True
        if i != 0:
            # collect both factors of every additional rank
            params.append(getattr(self, 'pw' + str(i)).weight)
            params.append(getattr(self, 'dw' + str(i)).weight)
    criterion = nn.MSELoss()
    optimizer = AdamW(params, lr=lr)
    st = time.time()
    for s in range(steps):
        # decay the learning rate at the scheduled steps
        if s in (400, 700):
            lr = lr / 10
            optimizer = AdamW(params, lr=lr)
        optimizer.zero_grad()
        # predicted kernel is the sum of rank-1 pointwise * depthwise products
        kernel_pred = 0
        for i in range(self.rank):
            kernel_pred = kernel_pred + \
                getattr(self, 'pw' + str(i)).weight.cuda() * \
                getattr(self, 'dw' + str(i)).weight.cuda()
        loss = criterion(kernel_pred, kernel.cuda())
        loss.backward()
        optimizer.step()
        if s % 100 == 99:
            print('step %d: loss = %f, time = %d'
                  % ((s + 1), loss.item(), (time.time() - st)))
            st = time.time()
def train_deembeders(self, tuples: List[Tuple[torch.Tensor, PDGEmbedder, PDGDeembedder]], epochs: int) -> List[float]: acc_list = [] for tuple in tuples: lab_data, embedder, deembedder = tuple deembed_optimizer = AdamW(deembedder.parameters(), lr=1e-4) deemb_loss = MSELoss() num_classes = len(lab_data) real_one_hot = func.one_hot(lab_data, num_classes=num_classes).float() for param in embedder.parameters(): param.requires_grad = False gen_one_hot = None for i in range(epochs): deembed_optimizer.zero_grad() embed = embedder(lab_data) gen_one_hot = deembedder(embed) err_deemb = deemb_loss(real_one_hot, gen_one_hot) err_deemb.backward() deembed_optimizer.step() acc = 0 gen_one_hot = (gen_one_hot > .5).int() diffs = torch.eq(real_one_hot, gen_one_hot).all(dim=1).int() size = len(diffs) acc += diffs.int().sum().float() acc /= size for param in embedder.parameters(): param.requires_grad = True acc_list.append(acc) return acc_list
def main( data_dir, save_path, batch_size, n_workers, valid_steps, warmup_steps, total_steps, save_steps, ): """Main function.""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"[Info]: Use {device} now!") train_loader, valid_loader, speaker_num = get_dataloader(data_dir, batch_size, n_workers) train_iterator = iter(train_loader) print(f"[Info]: Finish loading data!",flush = True) model = Classifier(n_spks=speaker_num).to(device) criterion = nn.CrossEntropyLoss() optimizer = AdamW(model.parameters(), lr=1e-3) scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps) print(f"[Info]: Finish creating model!",flush = True) best_accuracy = -1.0 best_state_dict = None pbar = tqdm(total=valid_steps, ncols=0, desc="Train", unit=" step") for step in range(total_steps): # Get data try: batch = next(train_iterator) except StopIteration: train_iterator = iter(train_loader) batch = next(train_iterator) loss, accuracy = model_fn(batch, model, criterion, device) batch_loss = loss.item() batch_accuracy = accuracy.item() # Updata model loss.backward() optimizer.step() scheduler.step() optimizer.zero_grad() # Log pbar.update() pbar.set_postfix( loss=f"{batch_loss:.2f}", accuracy=f"{batch_accuracy:.2f}", step=step + 1, ) # Do validation if (step + 1) % valid_steps == 0: pbar.close() valid_accuracy = valid(valid_loader, model, criterion, device) # keep the best model if valid_accuracy > best_accuracy: best_accuracy = valid_accuracy best_state_dict = model.state_dict() pbar = tqdm(total=valid_steps, ncols=0, desc="Train", unit=" step") # Save the best model so far. if (step + 1) % save_steps == 0 and best_state_dict is not None: torch.save(best_state_dict, save_path) pbar.write(f"Step {step + 1}, best model saved. (accuracy={best_accuracy:.4f})") pbar.close()
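# `model_fn` is an assumed helper in the loop above. A sketch with the return signature
# that loop expects (one forward pass yielding loss and batch accuracy); the batch layout
# (features, labels) is an assumption.
def model_fn(batch, model, criterion, device):
    feats, labels = batch
    feats, labels = feats.to(device), labels.to(device)
    outs = model(feats)
    loss = criterion(outs, labels)
    preds = outs.argmax(dim=-1)
    accuracy = (preds == labels).float().mean()
    return loss, accuracy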
class Distiller: def __init__(self, params: dict, dataset: LmSeqsDataset, token_probs: torch.tensor, student: nn.Module, teacher: nn.Module): logger.info('Initializing Distiller') self.params = params self.dump_path = params.dump_path self.multi_gpu = params.multi_gpu self.fp16 = params.fp16 self.student = student self.teacher = teacher self.student_config = student.config self.vocab_size = student.config.vocab_size if params.n_gpu <= 1: sampler = RandomSampler(dataset) else: sampler = DistributedSampler(dataset) if params.group_by_size: groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size) sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size) else: sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False) self.dataloader = DataLoader(dataset=dataset, batch_sampler=sampler, collate_fn=dataset.batch_sequences) self.temperature = params.temperature assert self.temperature > 0. self.alpha_ce = params.alpha_ce self.alpha_mlm = params.alpha_mlm self.alpha_clm = params.alpha_clm self.alpha_mse = params.alpha_mse self.alpha_cos = params.alpha_cos self.mlm = params.mlm if self.mlm: logger.info(f'Using MLM loss for LM step.') self.mlm_mask_prop = params.mlm_mask_prop assert 0.0 <= self.mlm_mask_prop <= 1.0 assert params.word_mask + params.word_keep + params.word_rand == 1.0 self.pred_probs = torch.FloatTensor( [params.word_mask, params.word_keep, params.word_rand]) self.pred_probs = self.pred_probs.to( f'cuda:{params.local_rank}' ) if params.n_gpu > 0 else self.pred_probs self.token_probs = token_probs.to( f'cuda:{params.local_rank}' ) if params.n_gpu > 0 else token_probs if self.fp16: self.pred_probs = self.pred_probs.half() self.token_probs = self.token_probs.half() else: logger.info(f'Using CLM loss for LM step.') self.epoch = 0 self.n_iter = 0 self.n_total_iter = 0 self.n_sequences_epoch = 0 self.total_loss_epoch = 0 self.last_loss = 0 self.last_loss_ce = 0 self.last_loss_mlm = 0 self.last_loss_clm = 0 if self.alpha_mse > 0.: self.last_loss_mse = 0 if self.alpha_cos > 0.: self.last_loss_cos = 0 self.last_log = 0 self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean') self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1) if self.alpha_mse > 0.: self.mse_loss_fct = nn.MSELoss(reduction='sum') if self.alpha_cos > 0.: self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction='mean') logger.info('--- Initializing model optimizer') assert params.gradient_accumulation_steps >= 1 self.num_steps_epoch = len(self.dataloader) num_train_optimization_steps = int( self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1 no_decay = ['bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [{ 'params': [ p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad ], 'weight_decay': params.weight_decay }, { 'params': [ p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad ], 'weight_decay': 0.0 }] logger.info( "------ Number of trainable parameters (student): %i" % sum([ p.numel() for p in self.student.parameters() if p.requires_grad ])) logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()])) self.optimizer = AdamW(optimizer_grouped_parameters, lr=params.learning_rate, eps=params.adam_epsilon, betas=(0.9, 0.98)) warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop) self.scheduler = get_linear_schedule_with_warmup( self.optimizer, 
num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps) if self.fp16: try: from apex import amp except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use fp16 training." ) logger.info( f"Using fp16 training: {self.params.fp16_opt_level} level") self.student, self.optimizer = amp.initialize( self.student, self.optimizer, opt_level=self.params.fp16_opt_level) self.teacher = self.teacher.half() if self.multi_gpu: if self.fp16: from apex.parallel import DistributedDataParallel logger.info( "Using apex.parallel.DistributedDataParallel for distributed training." ) self.student = DistributedDataParallel(self.student) else: from torch.nn.parallel import DistributedDataParallel logger.info( "Using nn.parallel.DistributedDataParallel for distributed training." ) self.student = DistributedDataParallel( self.student, device_ids=[params.local_rank], output_device=params.local_rank, find_unused_parameters=True) self.is_master = params.is_master if self.is_master: logger.info('--- Initializing Tensorboard') self.tensorboard = SummaryWriter( log_dir=os.path.join(self.dump_path, 'log', 'train')) self.tensorboard.add_text(tag='config/training', text_string=str(self.params), global_step=0) self.tensorboard.add_text(tag='config/student', text_string=str(self.student_config), global_step=0) def prepare_batch_mlm(self, batch): """ Prepare the batch: from the token_ids and the lenghts, compute the attention mask and the masked label for MLM. Input: ------ batch: `Tuple` token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded. lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch. Output: ------- token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM. attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention. mlm_labels: `torch.tensor(bs, seq_length)` - The masked languge modeling labels. There is a -1 where there is nothing to predict. 
""" token_ids, lengths = batch token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths) assert token_ids.size(0) == lengths.size(0) attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]) bs, max_seq_len = token_ids.size() mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids) x_prob = self.token_probs[token_ids.flatten()] n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item()) tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False) pred_mask = torch.zeros( bs * max_seq_len, dtype=torch.bool, device=token_ids.device ) # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility pred_mask[tgt_ids] = 1 pred_mask = pred_mask.view(bs, max_seq_len) pred_mask[token_ids == self.params.special_tok_ids['pad_token']] = 0 # mask a number of words == 0 [8] (faster with fp16) if self.fp16: n1 = pred_mask.sum().item() if n1 > 8: pred_mask = pred_mask.view(-1) n2 = max(n1 % 8, 8 * (n1 // 8)) if n2 != n1: pred_mask[torch.nonzero(pred_mask).view(-1)[:n1 - n2]] = 0 pred_mask = pred_mask.view(bs, max_seq_len) assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item() _token_ids_real = token_ids[pred_mask] _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size) _token_ids_mask = _token_ids_real.clone().fill_( self.params.special_tok_ids['mask_token']) probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True) _token_ids = _token_ids_mask * ( probs == 0).long() + _token_ids_real * ( probs == 1).long() + _token_ids_rand * (probs == 2).long() token_ids = token_ids.masked_scatter(pred_mask, _token_ids) mlm_labels[ ~pred_mask] = -1 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility # sanity checks assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size return token_ids, attn_mask, mlm_labels def prepare_batch_clm(self, batch): """ Prepare the batch: from the token_ids and the lenghts, compute the attention mask and the labels for CLM. Input: ------ batch: `Tuple` token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded. lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch. Output: ------- token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM. attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention. clm_labels: `torch.tensor(bs, seq_length)` - The causal languge modeling labels. There is a -1 where there is nothing to predict. """ token_ids, lengths = batch token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths) assert token_ids.size(0) == lengths.size(0) attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]) clm_labels = token_ids.new(token_ids.size()).copy_(token_ids) clm_labels[ ~attn_mask] = -1 # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility # sanity checks assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size return token_ids, attn_mask, clm_labels def round_batch(self, x: torch.tensor, lengths: torch.tensor): """ For float16 only. Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8. Input: ------ x: `torch.tensor(bs, seq_length)` - The token ids. lengths: `torch.tensor(bs, seq_length)` - The lengths of each of the sequence in the batch. Output: ------- x: `torch.tensor(new_bs, new_seq_length)` - The updated token ids. 
lengths: `torch.tensor(new_bs, new_seq_length)` - The updated lengths. """ if not self.fp16 or len(lengths) < 8: return x, lengths # number of sentences == 0 [8] bs1 = len(lengths) bs2 = 8 * (bs1 // 8) assert bs2 > 0 and bs2 % 8 == 0 if bs1 != bs2: idx = torch.randperm(bs1)[:bs2] lengths = lengths[idx] slen = lengths.max().item() x = x[idx, :slen] else: idx = None # sequence length == 0 [8] ml1 = x.size(1) if ml1 % 8 != 0: pad = 8 - (ml1 % 8) ml2 = ml1 + pad if self.mlm: pad_id = self.params.special_tok_ids['pad_token'] else: pad_id = self.params.special_tok_ids['unk_token'] padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id) x = torch.cat([x, padding_tensor], 1) assert x.size() == (bs2, ml2) assert x.size(0) % 8 == 0 assert x.size(1) % 8 == 0 return x, lengths def train(self): """ The real training loop. """ if self.is_master: logger.info('Starting training') self.last_log = time.time() self.student.train() self.teacher.eval() for _ in range(self.params.n_epoch): if self.is_master: logger.info( f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}') if self.multi_gpu: torch.distributed.barrier() iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0]) for batch in iter_bar: if self.params.n_gpu > 0: batch = tuple( t.to(f'cuda:{self.params.local_rank}') for t in batch) if self.mlm: token_ids, attn_mask, lm_labels = self.prepare_batch_mlm( batch=batch) else: token_ids, attn_mask, lm_labels = self.prepare_batch_clm( batch=batch) self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels) iter_bar.update() iter_bar.set_postfix({ 'Last_loss': f'{self.last_loss:.2f}', 'Avg_cum_loss': f'{self.total_loss_epoch/self.n_iter:.2f}' }) iter_bar.close() if self.is_master: logger.info( f'--- Ending epoch {self.epoch}/{self.params.n_epoch-1}') self.end_epoch() if self.is_master: logger.info(f'Save very last checkpoint as `pytorch_model.bin`.') self.save_checkpoint(checkpoint_name=f'pytorch_model.bin') logger.info('Training is finished') def step(self, input_ids: torch.tensor, attention_mask: torch.tensor, lm_labels: torch.tensor): """ One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation), and possibly a parameter update (depending on the gradient accumulation). Input: ------ input_ids: `torch.tensor(bs, seq_length)` - The token ids. attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention. lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM). 
""" if self.mlm: s_logits, s_hidden_states = self.student( input_ids=input_ids, attention_mask=attention_mask) # (bs, seq_length, voc_size) with torch.no_grad(): t_logits, t_hidden_states = self.teacher( input_ids=input_ids, attention_mask=attention_mask ) # (bs, seq_length, voc_size) else: s_logits, _, s_hidden_states = self.student( input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size) with torch.no_grad(): t_logits, _, t_hidden_states = self.teacher( input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size) assert s_logits.size() == t_logits.size() #https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 #https://github.com/peterliht/knowledge-distillation-pytorch/issues/2 if self.params.restrict_ce_to_mask: mask = (lm_labels > -1).unsqueeze(-1).expand_as( s_logits) # (bs, seq_lenth, voc_size) else: mask = attention_mask.unsqueeze(-1).expand_as( s_logits) # (bs, seq_lenth, voc_size) s_logits_slct = torch.masked_select( s_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask s_logits_slct = s_logits_slct.view(-1, s_logits.size( -1)) # (bs * seq_length, voc_size) modulo the 1s in mask t_logits_slct = torch.masked_select( t_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask t_logits_slct = t_logits_slct.view(-1, s_logits.size( -1)) # (bs * seq_length, voc_size) modulo the 1s in mask assert t_logits_slct.size() == s_logits_slct.size() loss_ce = self.ce_loss_fct( F.log_softmax(s_logits_slct / self.temperature, dim=-1), F.softmax(t_logits_slct / self.temperature, dim=-1)) * (self.temperature)**2 loss = self.alpha_ce * loss_ce if self.alpha_mlm > 0.: loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1)) loss += self.alpha_mlm * loss_mlm if self.alpha_clm > 0.: shift_logits = s_logits[..., :-1, :].contiguous() shift_labels = lm_labels[..., 1:].contiguous() loss_clm = self.lm_loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) loss += self.alpha_clm * loss_clm if self.alpha_mse > 0.: loss_mse = self.mse_loss_fct( s_logits_slct, t_logits_slct) / s_logits_slct.size( 0) # Reproducing batchmean reduction loss += self.alpha_mse * loss_mse if self.alpha_cos > 0.: s_hidden_states = s_hidden_states[-1] # (bs, seq_length, dim) t_hidden_states = t_hidden_states[-1] # (bs, seq_length, dim) mask = attention_mask.unsqueeze(-1).expand_as( s_hidden_states) # (bs, seq_length, dim) assert s_hidden_states.size() == t_hidden_states.size() dim = s_hidden_states.size(-1) s_hidden_states_slct = torch.masked_select( s_hidden_states, mask) # (bs * seq_length * dim) s_hidden_states_slct = s_hidden_states_slct.view( -1, dim) # (bs * seq_length, dim) t_hidden_states_slct = torch.masked_select( t_hidden_states, mask) # (bs * seq_length * dim) t_hidden_states_slct = t_hidden_states_slct.view( -1, dim) # (bs * seq_length, dim) target = s_hidden_states_slct.new( s_hidden_states_slct.size(0)).fill_(1) # (bs * seq_length,) loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target) loss += self.alpha_cos * loss_cos self.total_loss_epoch += loss.item() self.last_loss = loss.item() self.last_loss_ce = loss_ce.item() if self.alpha_mlm > 0.: self.last_loss_mlm = loss_mlm.item() if self.alpha_clm > 0.: self.last_loss_clm = loss_clm.item() if self.alpha_mse > 0.: self.last_loss_mse = loss_mse.item() if self.alpha_cos > 0.: self.last_loss_cos = loss_cos.item() self.optimize(loss) self.n_sequences_epoch += input_ids.size(0) def optimize(self, loss): """ 
Normalization on the loss (gradient accumulation or distributed training), followed by backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation). Also update the metrics for tensorboard. """ # Check for NaN if (loss != loss).data.any(): logger.error('NaN detected') exit() if self.multi_gpu: loss = loss.mean() if self.params.gradient_accumulation_steps > 1: loss = loss / self.params.gradient_accumulation_steps if self.fp16: from apex import amp with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() self.iter() if self.n_iter % self.params.gradient_accumulation_steps == 0: if self.fp16: torch.nn.utils.clip_grad_norm_( amp.master_params(self.optimizer), self.params.max_grad_norm) else: torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm) self.optimizer.step() self.optimizer.zero_grad() self.scheduler.step() def iter(self): """ Update global counts, write to tensorboard and save checkpoint. """ self.n_iter += 1 self.n_total_iter += 1 if self.n_total_iter % self.params.log_interval == 0: self.log_tensorboard() self.last_log = time.time() if self.n_total_iter % self.params.checkpoint_interval == 0: self.save_checkpoint() def log_tensorboard(self): """ Log into tensorboard. Only by the master process. """ if not self.is_master: return for param_name, param in self.student.named_parameters(): self.tensorboard.add_scalar(tag='parameter_mean/' + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter) self.tensorboard.add_scalar(tag='parameter_std/' + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter) if param.grad is None: continue self.tensorboard.add_scalar(tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=self.n_total_iter) self.tensorboard.add_scalar(tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter) self.tensorboard.add_scalar(tag="losses/cum_avg_loss_epoch", scalar_value=self.total_loss_epoch / self.n_iter, global_step=self.n_total_iter) self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter) self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter) if self.alpha_mlm > 0.: self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter) if self.alpha_clm > 0.: self.tensorboard.add_scalar(tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter) if self.alpha_mse > 0.: self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter) if self.alpha_cos > 0.: self.tensorboard.add_scalar(tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter) self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter) self.tensorboard.add_scalar( tag="global/memory_usage", scalar_value=psutil.virtual_memory()._asdict()['used'] / 1_000_000, global_step=self.n_total_iter) self.tensorboard.add_scalar(tag="global/speed", scalar_value=time.time() - self.last_log, global_step=self.n_total_iter) def end_epoch(self): """ Finally arrived at the end of epoch (full pass on dataset). Do some tensorboard logging and checkpoint saving. """ logger.info( f'{self.n_sequences_epoch} sequences have been trained during this epoch.' 
) if self.is_master: self.save_checkpoint( checkpoint_name=f'model_epoch_{self.epoch}.pth') self.tensorboard.add_scalar(tag='epoch/loss', scalar_value=self.total_loss_epoch / self.n_iter, global_step=self.epoch) self.epoch += 1 self.n_sequences_epoch = 0 self.n_iter = 0 self.total_loss_epoch = 0 def save_checkpoint(self, checkpoint_name: str = 'checkpoint.pth'): """ Save the current state. Only by the master process. """ if not self.is_master: return mdl_to_save = self.student.module if hasattr( self.student, 'module') else self.student mdl_to_save.config.save_pretrained(self.dump_path) state_dict = mdl_to_save.state_dict() torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
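# The grouped-parameter construction used for AdamW above (no weight decay on biases and
# LayerNorm weights) also works in isolation; a self-contained sketch with a toy module
# whose attribute names follow the same convention:
import torch.nn as nn
from torch.optim import AdamW


class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(16, 16)
        self.LayerNorm = nn.LayerNorm(16)   # attribute named to match the filter below

    def forward(self, x):
        return self.LayerNorm(self.dense(x))


model = TinyBlock()
no_decay = ['bias', 'LayerNorm.weight']   # substring match on parameter names
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=5e-4, eps=1e-6, betas=(0.9, 0.98))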
loss = distr_dalle(text, images, return_loss=True)

if args.deepspeed:
    distr_dalle.backward(loss)
else:
    loss.backward()

clip_grad_norm_(distr_dalle.parameters(), GRAD_CLIP_NORM)

if args.deepspeed:
    distr_dalle.step()  # Gradients are automatically zeroed after the step
else:
    opt.step()
    opt.zero_grad()

# Collective loss, averaged
avg_loss = deepspeed_utils.average_all(loss)

if deepspeed_utils.is_root_worker():
    torch.cuda.empty_cache()
    log = {}
    if i % 10 == 0:
        print(epoch, i, f'loss - {avg_loss.item()}')
        log = {
            **log,
            'epoch': epoch,
            'iter': i,
            'loss': avg_loss.item()
class Trainer(): def __init__(self, config, pretrained=True): self.config = config self.model, self.vocab = build_model(config) self.device = config['device'] self.num_iters = config['trainer']['iters'] self.beamsearch = config['predictor']['beamsearch'] self.data_root = config['dataset']['data_root'] self.train_annotation = config['dataset']['train_annotation'] self.valid_annotation = config['dataset']['valid_annotation'] self.dataset_name = config['dataset']['name'] self.batch_size = config['trainer']['batch_size'] self.print_every = config['trainer']['print_every'] self.valid_every = config['trainer']['valid_every'] self.checkpoint = config['trainer']['checkpoint'] self.export_weights = config['trainer']['export'] self.metrics = config['trainer']['metrics'] logger = config['trainer']['log'] if logger: self.logger = Logger(logger) if pretrained: weight_file = download_weights(**config['pretrain'], quiet=config['quiet']) self.load_weights(weight_file) self.iter = 0 self.optimizer = AdamW(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09) self.scheduler = OneCycleLR(self.optimizer, **config['optimizer']) # self.optimizer = ScheduledOptim( # Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09), # #config['transformer']['d_model'], # 512, # **config['optimizer']) self.criterion = LabelSmoothingLoss(len(self.vocab), padding_idx=self.vocab.pad, smoothing=0.1) transforms = ImgAugTransform() self.train_gen = self.data_gen('train_{}'.format(self.dataset_name), self.data_root, self.train_annotation, transform=transforms) if self.valid_annotation: self.valid_gen = self.data_gen( 'valid_{}'.format(self.dataset_name), self.data_root, self.valid_annotation) self.train_losses = [] def train(self): total_loss = 0 total_loader_time = 0 total_gpu_time = 0 best_acc = 0 data_iter = iter(self.train_gen) for i in range(self.num_iters): self.iter += 1 start = time.time() try: batch = next(data_iter) except StopIteration: data_iter = iter(self.train_gen) batch = next(data_iter) total_loader_time += time.time() - start start = time.time() loss = self.step(batch) total_gpu_time += time.time() - start total_loss += loss self.train_losses.append((self.iter, loss)) if self.iter % self.print_every == 0: info = 'iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format( self.iter, total_loss / self.print_every, self.optimizer.param_groups[0]['lr'], total_loader_time, total_gpu_time) total_loss = 0 total_loader_time = 0 total_gpu_time = 0 print(info) self.logger.log(info) if self.valid_annotation and self.iter % self.valid_every == 0: val_loss = self.validate() acc_full_seq, acc_per_char = self.precision(self.metrics) info = 'iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f}'.format( self.iter, val_loss, acc_full_seq, acc_per_char) print(info) self.logger.log(info) if acc_full_seq > best_acc: self.save_weights(self.export_weights) best_acc = acc_full_seq def validate(self): self.model.eval() total_loss = [] with torch.no_grad(): for step, batch in enumerate(self.valid_gen): batch = self.batch_to_device(batch) img, tgt_input, tgt_output, tgt_padding_mask = batch[ 'img'], batch['tgt_input'], batch['tgt_output'], batch[ 'tgt_padding_mask'] outputs = self.model(img, tgt_input, tgt_padding_mask) # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)')) outputs = outputs.flatten(0, 1) tgt_output = tgt_output.flatten() loss = self.criterion(outputs, tgt_output) total_loss.append(loss.item()) del outputs 
del loss total_loss = np.mean(total_loss) self.model.train() return total_loss def predict(self, sample=None): pred_sents = [] actual_sents = [] img_files = [] for batch in self.valid_gen: batch = self.batch_to_device(batch) if self.beamsearch: translated_sentence = batch_translate_beam_search( batch['img'], self.model) else: translated_sentence = translate(batch['img'], self.model) pred_sent = self.vocab.batch_decode(translated_sentence.tolist()) actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist()) img_files.extend(batch['filenames']) pred_sents.extend(pred_sent) actual_sents.extend(actual_sent) if sample != None and len(pred_sents) > sample: break return pred_sents, actual_sents, img_files def precision(self, sample=None): pred_sents, actual_sents, _ = self.predict(sample=sample) acc_full_seq = compute_accuracy(actual_sents, pred_sents, mode='full_sequence') acc_per_char = compute_accuracy(actual_sents, pred_sents, mode='per_char') return acc_full_seq, acc_per_char def visualize_prediction(self, sample=16, errorcase=False, fontname='serif', fontsize=16): pred_sents, actual_sents, img_files = self.predict(sample) if errorcase: wrongs = [] for i in range(len(img_files)): if pred_sents[i] != actual_sents[i]: wrongs.append(i) pred_sents = [pred_sents[i] for i in wrongs] actual_sents = [actual_sents[i] for i in wrongs] img_files = [img_files[i] for i in wrongs] img_files = img_files[:sample] fontdict = {'family': fontname, 'size': fontsize} for vis_idx in range(0, len(img_files)): img_path = img_files[vis_idx] pred_sent = pred_sents[vis_idx] actual_sent = actual_sents[vis_idx] img = Image.open(open(img_path, 'rb')) plt.figure() plt.imshow(img) plt.title('pred: {} - actual: {}'.format(pred_sent, actual_sent), loc='left', fontdict=fontdict) plt.axis('off') plt.show() def visualize_dataset(self, sample=16, fontname='serif'): n = 0 for batch in self.train_gen: for i in range(self.batch_size): img = batch['img'][i].numpy().transpose(1, 2, 0) sent = self.vocab.decode(batch['tgt_input'].T[i].tolist()) plt.figure() plt.title('sent: {}'.format(sent), loc='center', fontname=fontname) plt.imshow(img) plt.axis('off') n += 1 if n >= sample: plt.show() return def load_checkpoint(self, filename): checkpoint = torch.load(filename) optim = ScheduledOptim( Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09), self.config['transformer']['d_model'], **self.config['optimizer']) self.optimizer.load_state_dict(checkpoint['optimizer']) self.model.load_state_dict(checkpoint['state_dict']) self.iter = checkpoint['iter'] self.train_losses = checkpoint['train_losses'] def save_checkpoint(self, filename): state = { 'iter': self.iter, 'state_dict': self.model.state_dict(), 'optimizer': self.optimizer.state_dict(), 'train_losses': self.train_losses } path, _ = os.path.split(filename) os.makedirs(path, exist_ok=True) torch.save(state, filename) def load_weights(self, filename): state_dict = torch.load(filename, map_location=torch.device(self.device)) for name, param in self.model.named_parameters(): if name not in state_dict: print('{} not found'.format(name)) elif state_dict[name].shape != param.shape: print('{} missmatching shape'.format(name)) del state_dict[name] self.model.load_state_dict(state_dict, strict=False) def save_weights(self, filename): path, _ = os.path.split(filename) os.makedirs(path, exist_ok=True) torch.save(self.model.state_dict(), filename) def batch_to_device(self, batch): img = batch['img'].to(self.device, non_blocking=True) tgt_input = batch['tgt_input'].to(self.device, 
non_blocking=True) tgt_output = batch['tgt_output'].to(self.device, non_blocking=True) tgt_padding_mask = batch['tgt_padding_mask'].to(self.device, non_blocking=True) batch = { 'img': img, 'tgt_input': tgt_input, 'tgt_output': tgt_output, 'tgt_padding_mask': tgt_padding_mask, 'filenames': batch['filenames'] } return batch def data_gen(self, lmdb_path, data_root, annotation, transform=None): dataset = OCRDataset( lmdb_path=lmdb_path, root_dir=data_root, annotation_path=annotation, vocab=self.vocab, transform=transform, image_height=self.config['dataset']['image_height'], image_min_width=self.config['dataset']['image_min_width'], image_max_width=self.config['dataset']['image_max_width']) sampler = ClusterRandomSampler(dataset, self.batch_size, True) gen = DataLoader(dataset, batch_size=self.batch_size, sampler=sampler, collate_fn=collate_fn, shuffle=False, drop_last=False, **self.config['dataloader']) return gen def data_gen_v1(self, lmdb_path, data_root, annotation): data_gen = DataGen( data_root, annotation, self.vocab, 'cpu', image_height=self.config['dataset']['image_height'], image_min_width=self.config['dataset']['image_min_width'], image_max_width=self.config['dataset']['image_max_width']) return data_gen def step(self, batch): self.model.train() batch = self.batch_to_device(batch) img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch[ 'tgt_input'], batch['tgt_output'], batch['tgt_padding_mask'] outputs = self.model(img, tgt_input, tgt_key_padding_mask=tgt_padding_mask) # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)')) outputs = outputs.view(-1, outputs.size(2)) #flatten(0, 1) tgt_output = tgt_output.view(-1) #flatten() loss = self.criterion(outputs, tgt_output) self.optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1) self.optimizer.step() self.scheduler.step() loss_item = loss.item() return loss_item
def train(args): logger = log.get_logger(__name__) with open(Path(args.config_base_path, args.config).with_suffix(".yaml"), 'r') as f: config = yaml.safe_load(f) train_transforms = transforms.get_train_transforms() val_transforms = transforms.get_val_transforms() logger.info("Loading the dataset...") if config['dataset']['name'] == 'coco_subset': # TODO: Look into train_transforms hiding the objects # Transform in such a way that this can't be the case train_dataset = CocoSubset(config['dataset']['coco_path'], config['dataset']['target_classes'], train_transforms, 'train', config['dataset']['train_val_split']) val_dataset = CocoSubset(config['dataset']['coco_path'], config['dataset']['target_classes'], val_transforms, 'val', config['dataset']['train_val_split']) else: raise ValueError("Dataset name not recognized or implemented") train_loader = DataLoader(train_dataset, config['training']['batch_size'], shuffle=True, collate_fn=data_utils.collate_fn) val_loader = DataLoader(val_dataset, config['training']['batch_size'], shuffle=True, collate_fn=data_utils.collate_fn) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") checkpoint_manager = CheckpointManager(args.config, args.save_every) logger.info("Loading model...") model = models.DETR(config['dataset']['num_classes'], config['model']['dim_model'], config['model']['n_heads'], n_queries=config['model']['n_queries'], head_type=config['model']['head_type']) # TODO: implement scheduler optim = AdamW(model.parameters(), config['training']['lr']) # pending if args.mode == 'pretrained': model.load_demo_state_dict('data/state_dicts/detr_demo.pth') elif args.mode == 'checkpoint': state_dict, optim_dict = checkpoint_manager.load_checkpoint('latest') model.load_state_dict(state_dict) optim.load_state_dict(optim_dict) if args.train_section == 'head': to_train = ['ffn'] elif args.train_section == 'backbone': to_train = ['backbone', 'conv'] else: to_train = ['ffn', 'backbone', 'conv', 'transformer', 'row', 'col', 'object'] # Freeze everything but the modules that are in to_train for name, param in model.named_parameters(): if not any(map(name.startswith, to_train)): param.requires_grad = False model.to(device) matcher = models.HungarianMatcher(config['losses']['lambda_matcher_classes'], config['losses']['lambda_matcher_giou'], config['losses']['lambda_matcher_l1']) loss_fn = models.DETRLoss(config['losses']['lambda_loss_classes'], config['losses']['lambda_loss_giou'], config['losses']['lambda_loss_l1'], config['dataset']['num_classes'], config['losses']['no_class_weight']) # writer = SummaryWriter(log_dir=Path(__file__)/'logs/tensorboard') # maybe image with boxes every now and then # maybe look into add_hparams logger.info("Starting training...") loss_hist = deque() loss_desc = "Loss: n/a" update_every_n_steps = config['training']['effective_batch_size'] // config['training']['batch_size'] steps = 1 starting_epoch = checkpoint_manager.current_epoch for epoch in range(starting_epoch, config['training']['epochs']): epoch_desc = f"Epoch [{epoch}/{config['training']['epochs']}]" for images, labels in tqdm(train_loader, f"{epoch_desc} | {loss_desc}"): images = images.to(device) labels = data_utils.labels_to_device(labels, device) output = model(images) matching_indices = matcher(output, labels) matching_indices = data_utils.indices_to_device(matching_indices, device) loss = loss_fn(output, labels, matching_indices) / update_every_n_steps loss_hist.append(loss.item() * update_every_n_steps) loss.backward() if steps % update_every_n_steps 
== 0: optim.step() optim.zero_grad() steps += 1 checkpoint_manager.step(model, optim, sum(loss_hist) / len(loss_hist)) loss_desc = f"Loss: {sum(loss_hist)/len(loss_hist)}" loss_hist.clear() if (epoch % args.eval_every == 0) and epoch != 0: validation_loop(model, matcher, val_loader, loss_fn, device) checkpoint_manager.save_checkpoint(model, optim, sum(loss_hist) / len(loss_hist))
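# The accumulation pattern used in the loop above, reduced to a self-contained sketch:
# the loss is divided by the number of micro-batches so the accumulated gradient matches
# a single large-batch step. Model, data, and sizes here are illustrative.
import torch
import torch.nn as nn
from torch.optim import AdamW

model = nn.Linear(8, 2)
optim = AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

effective_batch_size, batch_size = 64, 16
update_every_n_steps = effective_batch_size // batch_size

steps = 1
for _ in range(100):   # stands in for iterating a DataLoader
    images = torch.randn(batch_size, 8)
    labels = torch.randint(0, 2, (batch_size,))
    loss = criterion(model(images), labels) / update_every_n_steps
    loss.backward()    # gradients accumulate across micro-batches
    if steps % update_every_n_steps == 0:
        optim.step()
        optim.zero_grad()
    steps += 1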
class Model: def __init__(self, local_rank=-1, arbitrary=False): if arbitrary == True: self.flownet = IFNet_m() else: self.flownet = IFNet() self.device() self.optimG = AdamW( self.flownet.parameters(), lr=1e-6, weight_decay=1e-3) # use large weight decay may avoid NaN loss self.epe = EPE() self.lap = LapLoss() self.sobel = SOBEL() if local_rank != -1: self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank) def train(self): self.flownet.train() def eval(self): self.flownet.eval() def device(self): self.flownet.to(device) def load_model(self, path, rank=0): def convert(param): return { k.replace("module.", ""): v for k, v in param.items() if "module." in k } if rank <= 0: self.flownet.load_state_dict( convert(torch.load('{}/flownet.pkl'.format(path)))) def save_model(self, path, rank=0): if rank == 0: torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path)) def inference(self, img0, img1, scale_list=[4, 2, 1], TTA=False, timestep=0.5): imgs = torch.cat((img0, img1), 1) flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet( imgs, scale_list, timestep=timestep) if TTA == False: return merged[2] else: flow2, mask2, merged2, flow_teacher2, merged_teacher2, loss_distill2 = self.flownet( imgs.flip(2).flip(3), scale_list, timestep=timestep) return (merged[2] + merged2[2].flip(2).flip(3)) / 2 def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): for param_group in self.optimG.param_groups: param_group['lr'] = learning_rate img0 = imgs[:, :3] img1 = imgs[:, 3:] if training: self.train() else: self.eval() flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet( torch.cat((imgs, gt), 1), scale=[4, 2, 1]) loss_l1 = (self.lap(merged[2], gt)).mean() loss_tea = (self.lap(merged_teacher, gt)).mean() if training: self.optimG.zero_grad() loss_G = loss_l1 + loss_tea + loss_distill * 0.01 loss_G.backward() self.optimG.step() else: flow_teacher = flow[2] return merged[2], { 'merged_tea': merged_teacher, 'mask': mask, 'mask_tea': mask, 'flow': flow[2][:, :2], 'flow_tea': flow_teacher, 'loss_l1': loss_l1, 'loss_tea': loss_tea, 'loss_distill': loss_distill, }
def main(): # 如果可以使用GPU运算,则使用GPU,否则使用CPU device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') print("Use " + str(device)) # 图片预处理的方法 img_transform = transforms.Compose([ # 将图片转换为tensor类型并缩放到[0,1]的区间内 transforms.ToTensor(), # 将图片再缩放到[-1.1]的区间内 transforms.Normalize((0.5, ), (0.5, )), ]) # 创建输出文件夹 if not os.path.exists(config.output_path): os.mkdir(config.output_path) # 创建dataset mnist_dataset = Digit_train_Dataset(pd.read_csv("MNIST.csv"), transform=img_transform) # 创建dataloader mnist_loader = DataLoader(dataset=mnist_dataset, batch_size=config.batchSize, shuffle=True) # 从model中获取判别器D和生成器G的网络模型 G_model = get_G_model(config.from_old_model, device, config.G_model_path, config.G_type) D_model = get_D_model(config.from_old_model, device, config.D_model_path) # 定义G和D的优化器,此处使用AdamW优化器,学习率为1e-4 G_optimizer = AdamW(G_model.parameters(), lr=1e-4, weight_decay=1e-6) D_optimizer = AdamW(D_model.parameters(), lr=1e-4, weight_decay=1e-6) # 损失函数 criterion = config.criterion # 记录训练时间 train_start = time.time() # 开始训练的每一个epoch for epoch in range(config.epochs): print("start epoch " + str(epoch + 1) + ":") # 定义一些变量用于记录进度和损失 batch_num = len(mnist_loader) D_loss_sum = 0 G_loss_sum = 0 count = 0 # 从dataloader中提取数据 for index, images in enumerate(mnist_loader): count += 1 # 将图片放入运算设备的内存 images = images.to(device) # 定义真标签,使用标签平滑的策略,生成0.9到1之间的随机数作为真实标签 real_labels = (1 - torch.rand(config.batchSize, 1) / 10).to(device) # 定义假标签,单向平滑,因此不对生成器标签进行平滑处理,全0 fake_labels = Variable(torch.zeros(config.batchSize, 1)).to(device) # 将随机的初始数据喂入生成器生成假图像 img_seeds = torch.randn(config.batchSize, config.img_seed_dim).to(device) fake_images = G_model(img_seeds) # 记录真假标签是否被交换过 exchange_labels = False # 有一定概率在训练判别器时交换label if random.uniform(0, 1) < config.D_train_label_exchange: real_labels, fake_labels = fake_labels, real_labels exchange_labels = True # 训练判断器D D_optimizer.zero_grad() # 用真样本输入判别器 real_output = D_model(images) # 对于数据集末尾的数据,长度不够一个batch size时需要去除过长的真实标签 if len(real_labels) > len(real_output): D_loss_real = criterion(real_output, real_labels[:len(real_output)]) else: D_loss_real = criterion(real_output, real_labels) # 用假样本输入判别器 fake_output = D_model(fake_images) D_loss_fake = criterion(fake_output, fake_labels) # 将真样本与假样本损失相加,得到判别器的损失 D_loss = D_loss_real + D_loss_fake D_loss_sum += D_loss.item() # 重置优化器 D_optimizer.zero_grad() # 用损失更新判别器D D_loss.backward() D_optimizer.step() # 如果之前交换过标签,此时再换回来 if exchange_labels: real_labels, fake_labels = fake_labels, real_labels # 训练生成器G # 将随机种子数喂入生成器G生成假数据 img_seeds = torch.randn(config.batchSize, config.img_seed_dim).to(device) fake_images = G_model(img_seeds) # 将假数据输入判别器 fake_output = D_model(fake_images) # 将假数据的判别结果与真实标签对比得到损失 G_loss = criterion(fake_output, real_labels) G_loss_sum += G_loss.item() # 重置优化器 G_optimizer.zero_grad() # 利用损失更新生成器G G_loss.backward() G_optimizer.step() # 打印程序工作进度 if (index + 1) % 200 == 0: print("Epoch: %2d, Batch: %4d / %4d" % (epoch + 1, index + 1, batch_num)) # 在每个epoch结束时保存模型参数到磁盘文件 torch.save(G_model.state_dict(), config.G_model_path) torch.save(D_model.state_dict(), config.D_model_path) # 在每个epoch结束时输出一组生成器产生的图片到输出文件夹 img_seeds = torch.randn(config.batchSize, config.img_seed_dim).to(device) fake_images = G_model(img_seeds).cuda().data # 将假图像缩放到[0,1]的区间 fake_images = 0.5 * (fake_images + 1) fake_images = fake_images.clamp(0, 1) # 连接所有生成的图片然后用自带的save_image()函数输出到磁盘文件 fake_images = fake_images.view(-1, 1, 28, 28) save_image(fake_images, config.output_path + str(epoch + 1) + '.png') # 打印该epoch的损失,时间等数据用于参考 print("D_loss:", 
round(D_loss_sum / count, 3)) print("G_loss:", round(G_loss_sum / count, 3)) current_time = time.time() pass_time = int(current_time - train_start) time_string = str(pass_time // 3600) + " hours, " + str( (pass_time % 3600) // 60) + " minutes, " + str( pass_time % 60) + " seconds." print("Time pass: " + time_string) print("Done.")
class Model: def __init__(self, local_rank=-1): self.flownet = IFNet() self.contextnet = ContextNet() self.fusionnet = FusionNet() self.device() self.optimG = AdamW(itertools.chain(self.flownet.parameters(), self.contextnet.parameters(), self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5) self.schedulerG = optim.lr_scheduler.CyclicLR(self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False) self.epe = EPE() self.ter = Ternary() self.sobel = SOBEL() if local_rank != -1: self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank) self.contextnet = DDP(self.contextnet, device_ids=[local_rank], output_device=local_rank) self.fusionnet = DDP(self.fusionnet, device_ids=[local_rank], output_device=local_rank) def train(self): self.flownet.train() self.contextnet.train() self.fusionnet.train() def eval(self): self.flownet.eval() self.contextnet.eval() self.fusionnet.eval() def device(self): self.flownet.to(device) self.contextnet.to(device) self.fusionnet.to(device) def load_model(self, path, rank=-1): def convert(param): if rank == -1: return { k.replace("module.", ""): v for k, v in param.items() if "module." in k } else: return param if rank <= 0: self.flownet.load_state_dict( convert( torch.load('{}/flownet.pkl'.format(path), map_location=device))) self.contextnet.load_state_dict( convert( torch.load('{}/contextnet.pkl'.format(path), map_location=device))) self.fusionnet.load_state_dict( convert( torch.load('{}/unet.pkl'.format(path), map_location=device))) def save_model(self, path, rank): if rank == 0: torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path)) torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path)) torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path)) def predict(self, imgs, flow, training=True, flow_gt=None): img0 = imgs[:, :3] img1 = imgs[:, 3:] flow = F.interpolate( flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0 c0 = self.contextnet(img0, flow[:, :2]) c1 = self.contextnet(img1, flow[:, 2:4]) refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet( img0, img1, flow, c0, c1, flow_gt) res = torch.sigmoid(refine_output[:, :3]) * 2 - 1 mask = torch.sigmoid(refine_output[:, 3:4]) merged_img = warped_img0 * mask + warped_img1 * (1 - mask) pred = merged_img + res pred = torch.clamp(pred, 0, 1) if training: return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt else: return pred def inference(self, img0, img1): imgs = torch.cat((img0, img1), 1) flow, _ = self.flownet(imgs) return self.predict(imgs, flow, training=False) def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): for param_group in self.optimG.param_groups: param_group['lr'] = learning_rate if training: self.train() else: self.eval() flow, flow_list = self.flownet(imgs) pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict( imgs, flow, flow_gt=flow_gt) loss_ter = self.ter(pred, gt).mean() if training: with torch.no_grad(): loss_flow = torch.abs(warped_img0_gt - gt).mean() loss_mask = torch.abs(merged_img - gt).sum( 1, True).float().detach() loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear", align_corners=False).detach() flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5).detach() loss_cons = 0 for i in range(3): loss_cons += self.epe(flow_list[i][:, :2], flow_gt[:, :2], 1) loss_cons += 
self.epe(flow_list[i][:, 2:4], flow_gt[:, 2:4], 1) loss_cons = loss_cons.mean() * 0.01 else: loss_cons = torch.tensor([0]) loss_flow = torch.abs(warped_img0 - gt).mean() loss_mask = 1 loss_l1 = (((pred - gt)**2 + 1e-6)**0.5).mean() if training: self.optimG.zero_grad() loss_G = loss_l1 + loss_cons + loss_ter loss_G.backward() self.optimG.step() return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask
def train_model(self, train_x, train_y, train_mask, dev_x, dev_y, dev_mask, test_x, test_y, test_mask, lr=1e-5, batch_size=16, aux_batch_size=4, use_aux=False, sampling='uniform', all_neg=False, model_path='models'): model_name = "model_{}".format("ns" if use_aux else "base") if use_aux: model_name += "_{}".format("allneg" if all_neg else "normal") all_params = [p for _, p in self.bert.named_parameters()] + [self.linear.weight, self.linear.bias] num_train_steps = len(train_x) // batch_size * 50 optimizer = AdamW(all_params, lr=lr) scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(num_train_steps * 0.1), num_training_steps=num_train_steps) train_y_numpy = [l.numpy() for l in train_y] tor = 0 max_score = -1 label_group = make_label_group(train_y_numpy) stack = [] count = 1 steps = 0 train_losses_main = [] train_losses_aux = [] dev_losses_main = [] dev_losses_aux = [] all_idx = set(list(range(train_x.shape[0]))) while tor < 10: if count > 50: break st = time.time() print('epoch ', count) count += 1 self.train() for bx, by, bmask, prg, bidx in self.data_generator(train_x, train_y, train_mask): bx = bx.to(device) by = by.to(device) bmask = bmask.to(device) if use_aux: aux_x, aux_y, aux_label, aux_idx = self.aux_task_sampling( train_x, train_y, by, bidx, label_group, batch=aux_batch_size, at_random=('rand' in sampling), all_neg=all_neg) aux_mask = torch.cuda.LongTensor([train_mask[i].numpy() for i in aux_idx]) aux_x = torch.stack(aux_x).type(torch.LongTensor).cuda() aux_y = torch.cuda.LongTensor(aux_y) all_loss, main_loss, aux_loss = self.calc_loss(bx, aux_x, bmask, aux_mask, by, aux_y, all_neg, use_aux=True) else: all_loss, main_loss = self.calc_loss(bx, None, bmask, None, by, None, False, use_aux=False) optimizer.zero_grad() all_loss.backward() loss_value = all_loss.detach().cpu().numpy() optimizer.step() scheduler.step() print('progress: {:.2f}%, loss = {:.5f}\r'.format(prg * 100, loss_value), end='', flush=True) steps += 1 print('') self.eval() with torch.no_grad(): score, _, losses = self.evaluate(dev_x, dev_y, dev_mask, all_neg=all_neg) score_test, _, _ = self.evaluate(test_x, test_y, test_mask, all_neg=all_neg) print('dev exact match = ', score, flush=True) print('test exact match = ', score_test, flush=True) if max_score < score: max_score = score tor = 0 with torch.no_grad(): max_score_test, preds_test, _ = self.evaluate(test_x, test_y, test_mask) torch.save(self.state_dict(), os.path.join(model_path, model_name) + "_{}".format(len(os.listdir(model_path)))) else: tor += 1 ed = time.time() print('time = ', ed - st) self.train() print('finish') all_losses = { 'main_losses_train': train_losses_main, 'aux_losses_train': train_losses_aux, 'main_losses_dev': dev_losses_main, 'aux_losses_dev': dev_losses_aux} return max_score_test, max_score, all_losses
def run(self): """Run the training""" start = timeit.default_timer() no_improve_count = 0 AdamW_optim = AdamW(self.weights, lr=self.init_lr) SGD_optim = torch.optim.SGD(self.biases, lr=self.init_lr) AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( AdamW_optim, factor=0.5, patience=100, threshold=0) SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( SGD_optim, factor=0.5, patience=100, threshold=0) while True: rmse = self.evaluate(self.validation_set) learning_rate = AdamW_optim.param_groups[0]['lr'] if learning_rate < self.min_lr or AdamW_scheduler.last_epoch > self.nmax: break # checkpoint if AdamW_scheduler.is_better(rmse, AdamW_scheduler.best): no_improve_count = 0 torch.save(self.nn.state_dict(), self.model_checkpoint) else: no_improve_count += 1 if no_improve_count > self.max_nonimprove: break AdamW_scheduler.step(rmse) SGD_scheduler.step(rmse) if self.tensorboard is not None: self.tensorboard.add_scalar('validation_rmse', rmse, AdamW_scheduler.last_epoch) self.tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best, AdamW_scheduler.last_epoch) self.tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch) self.tensorboard.add_scalar('no_improve_count_vs_epoch', no_improve_count, AdamW_scheduler.last_epoch) for i, properties in self.tqdm( enumerate(self.training_set), total=len(self.training_set), desc='epoch {}'.format(AdamW_scheduler.last_epoch)): species = properties['species'].to(self.device) coordinates = properties['coordinates'].to( self.device).float() true_energies = properties['energies'].to( self.device).float() num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype) _, predicted_energies = self.model((species, coordinates)) loss = (self.mse_se(predicted_energies, true_energies) / num_atoms.sqrt()).mean() AdamW_optim.zero_grad() SGD_optim.zero_grad() loss.backward() AdamW_optim.step() SGD_optim.step() # write current batch loss to TensorBoard if self.tensorboard is not None: self.tensorboard.add_scalar( 'batch_loss', loss, AdamW_scheduler.last_epoch * len(self.training_set) + i) # log elapsed time elapsed = round(timeit.default_timer() - start, 2) if self.tensorboard is not None: self.tensorboard.add_scalar('time_vs_epoch', elapsed, AdamW_scheduler.last_epoch)
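# A minimal sketch of the two-optimizer pattern in run() above: weight matrices go to
# AdamW (which applies decoupled weight decay by default) while biases go to plain SGD,
# and both ReduceLROnPlateau schedulers react to the same validation metric. The network,
# data, and RMSE computation are placeholders.
import torch
import torch.nn as nn
from torch.optim import SGD, AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

net = nn.Sequential(nn.Linear(16, 32), nn.CELU(), nn.Linear(32, 1))
weights = [p for name, p in net.named_parameters() if name.endswith("weight")]
biases = [p for name, p in net.named_parameters() if name.endswith("bias")]

adamw_optim = AdamW(weights, lr=1e-3)
sgd_optim = SGD(biases, lr=1e-3)
adamw_sched = ReduceLROnPlateau(adamw_optim, factor=0.5, patience=100, threshold=0)
sgd_sched = ReduceLROnPlateau(sgd_optim, factor=0.5, patience=100, threshold=0)

for epoch in range(5):
    x, y = torch.randn(64, 16), torch.randn(64, 1)
    loss = nn.functional.mse_loss(net(x), y)
    adamw_optim.zero_grad()
    sgd_optim.zero_grad()
    loss.backward()
    adamw_optim.step()
    sgd_optim.step()

    with torch.no_grad():
        val_rmse = nn.functional.mse_loss(net(x), y).sqrt()
    adamw_sched.step(val_rmse)  # both schedulers follow the validation RMSE
    sgd_sched.step(val_rmse)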
def prepare_data(dataset): data_path = 'data/{}'.format(dataset) data_full_path = '{}.npz'.format(data_path) if exists(data_full_path): data = np.load(data_full_path, allow_pickle=True) if 'cifar' in dataset or 'celeb' in dataset: return data['xs'], data['hs'], data['ys'], data['ps'] else: return data['xs'], data['ys'], data['ps'] if 'cifar' in dataset: device = 'cuda:0' model = models.resnet152(num_classes=10).to(device) model_path = 'data/cifar_resnet152' transform = Compose([ ToTensor(), Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) ]) data = CIFAR10(root='data', download=True, transform=transform) if exists(model_path): model.load_state_dict(load(model_path)) else: loader = DataLoader(data, batch_size=1024, shuffle=True) loss = CrossEntropyLoss() optim = AdamW(model.parameters(), amsgrad=True) for epoch in range(100): epoch_loss = 0 for _data in loader: xs, ys = _data optim.zero_grad() _loss = loss(model(xs.to(device)), ys.to(device)) _loss.backward() optim.step() epoch_loss += _loss.item() print('Epoch {}: Loss: {}'.format(epoch + 1, epoch_loss)) save(model.state_dict(), model_path) xs, ys = data.data, np.array(data.targets) first_digit, second_digit = int(dataset[5]), int(dataset[7]) mask = np.logical_or(ys == first_digit, ys == second_digit) xs = xs[mask] ys = np.array([1 if y == first_digit else 0 for y in ys[mask]]) loader = DataLoader(data, batch_size=1024) hs = None model.eval() for _data in loader: _xs, _ys = _data _mask = (_ys == first_digit) | (_ys == second_digit) if len(_mask) > 0: _hs = resnet_fc(model, _xs[_mask].to(device)) _hs = _hs.cpu().detach().numpy() hs = _hs if hs is None else np.concatenate((hs, _hs), axis=0) clf = LogisticRegression(n_jobs=-1, max_iter=100000) clf.fit(hs, ys) ps = clf.predict_proba(hs)[:, 1] np.savez(data_path, xs=xs, hs=hs, ys=ys, ps=ps) return xs, hs, ys, ps if 'mnist' in dataset: xs, ys = fetch('mnist_784') first_digit, second_digit = dataset[5], dataset[7] elif 'fashion' in dataset: xs, ys = fetch('Fashion-MNIST') first_digit, second_digit = dataset[7], dataset[9] elif 'kuzushi' in dataset: xs, ys = fetch('Kuzushiji-MNIST') first_digit, second_digit = dataset[7], dataset[9] mask = np.logical_or(ys == first_digit, ys == second_digit) xs = xs[mask] / 255 ys = np.array([1 if y == first_digit else 0 for y in ys[mask]]) clf = LogisticRegression(n_jobs=-1, max_iter=100000) clf.fit(xs, ys) ps = clf.predict_proba(xs)[:, 1] np.savez(data_path, xs=xs, ys=ys, ps=ps) return xs, ys, ps
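# A rough sketch of the feature-extraction step in prepare_data above: the classifier's
# final fully connected layer is replaced with Identity so the penultimate activations
# can be fed to a scikit-learn LogisticRegression, whose predicted probabilities play the
# role of the `ps` array. The untrained ResNet-18 and random labels are stand-ins only.
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from torchvision import models

model = models.resnet18(num_classes=10)
model.fc = torch.nn.Identity()  # expose the 512-d penultimate features
model.eval()

xs = torch.randn(256, 3, 32, 32)
ys = np.random.randint(0, 2, size=256)

with torch.no_grad():
    hs = model(xs).numpy()

clf = LogisticRegression(max_iter=100000, n_jobs=-1)
clf.fit(hs, ys)
ps = clf.predict_proba(hs)[:, 1]  # probability of the positive class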
class Seq2seqKpGen(object): """High level model that handles intializing the underlying network architecture, saving, updating examples, and predicting examples. """ # -------------------------------------------------------------------------- # Initialization # -------------------------------------------------------------------------- def __init__(self, args, word_dict, state_dict=None): # Book-keeping. self.args = args self.word_dict = word_dict self.args.vocab_size = len(word_dict) self.updates = 0 self.network = Sequence2Sequence(self.args, self.word_dict) if state_dict: self.network.load_state_dict(state_dict) def activate_fp16(self): if not hasattr(self, 'optimizer'): self.network.half() # for testing only return try: global amp from apex import amp except ImportError: raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") # https://github.com/NVIDIA/apex/issues/227 assert self.optimizer is not None self.network, self.optimizer = amp.initialize(self.network, self.optimizer, opt_level=self.args.fp16_opt_level) def init_optimizer(self, optim_state=None, sched_state=None): def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1): def lr_lambda(current_step: int): if current_step < num_warmup_steps: return float(current_step) / float(max(1.0, num_warmup_steps)) return 1.0 return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) no_decay = ["bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { "params": [p for n, p in self.network.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": self.args.weight_decay, }, {"params": [p for n, p in self.network.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, ] self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate) self.scheduler = get_constant_schedule_with_warmup(self.optimizer, self.args.warmup_steps) if optim_state: self.optimizer.load_state_dict(optim_state) if self.args.device.type == 'cuda': for state in self.optimizer.state.values(): for k, v in state.items(): if torch.is_tensor(v): state[k] = v.to(self.args.device) if sched_state: self.scheduler.load_state_dict(sched_state) # -------------------------------------------------------------------------- # Learning # -------------------------------------------------------------------------- def update(self, ex): """Forward a batch of examples; step the optimizer to update weights.""" if not self.optimizer: raise RuntimeError('No optimizer set.') # Train mode self.network.train() source_map, alignment = None, None if self.args.copy_attn: source_map = make_src_map(ex['src_map']).to(self.args.device) alignment = align(ex['alignment']).to(self.args.device) source_rep = ex['source_rep'].to(self.args.device) source_len = ex['source_len'].to(self.args.device) target_rep = ex['target_rep'].to(self.args.device) target_len = ex['target_len'].to(self.args.device) # Run forward ml_loss, loss_per_token = self.network(source=source_rep, source_len=source_len, target=target_rep, target_len=target_len, src_map=source_map, alignment=alignment) loss = ml_loss.mean() if self.args.n_gpu > 1 else ml_loss if self.args.fp16: global amp with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() clip_grad_norm_(amp.master_params(self.optimizer), self.args.grad_clipping) else: loss.backward() clip_grad_norm_(self.network.parameters(), self.args.grad_clipping) self.updates += 1 self.optimizer.step() self.scheduler.step() # Update learning rate 
schedule self.optimizer.zero_grad() loss_per_token = loss_per_token.mean() if self.args.n_gpu > 1 else loss_per_token loss_per_token = loss_per_token.item() loss_per_token = 10 if loss_per_token > 10 else loss_per_token perplexity = math.exp(loss_per_token) return { 'ml_loss': loss.item(), 'perplexity': perplexity } # -------------------------------------------------------------------------- # Prediction # -------------------------------------------------------------------------- def predict(self, ex, replace_unk=False): """Forward a batch of examples only to get predictions. Args: ex: the batch examples replace_unk: replace `unk` tokens while generating predictions src_raw: raw source (passage); required to replace `unk` term Output: predictions: #batch predicted sequences """ def convert_text_to_string(text): """ Converts a sequence of tokens (string) in a single string. """ out_string = text.replace(" ##", "").strip() return out_string self.network.eval() source_map, alignment = None, None blank, fill = None, None if self.args.copy_attn: source_map = make_src_map(ex['src_map']).to(self.args.device) alignment = align(ex['alignment']).to(self.args.device) blank, fill = collapse_copy_scores(self.word_dict, ex['src_vocab']) source_rep = ex['source_rep'].to(self.args.device) source_len = ex['source_len'].to(self.args.device) decoder_out = self.network(source=source_rep, source_len=source_len, target=None, target_len=None, src_map=source_map, alignment=alignment, max_len=self.args.max_tgt_len, tgt_dict=self.word_dict, blank=blank, fill=fill, source_vocab=ex['src_vocab']) dec_probs = torch.exp(decoder_out['dec_log_probs']) predictions, scores = tens2sen_score(decoder_out['predictions'], dec_probs, self.word_dict, ex['src_vocab']) if replace_unk: for i in range(len(predictions)): enc_dec_attn = decoder_out['attentions'][i] if self.args.model_type == 'transformer': # tgt_len x num_heads x src_len assert enc_dec_attn.dim() == 3 enc_dec_attn = enc_dec_attn.mean(1) predictions[i] = replace_unknown(predictions[i], enc_dec_attn, src_raw=ex['source'][i].tokens) for bidx in range(ex['batch_size']): for i in range(len(predictions[bidx])): if predictions[bidx][i] == constants.KP_SEP: scores[bidx][i] = constants.KP_SEP elif predictions[bidx][i] == constants.PRESENT_EOS: scores[bidx][i] = constants.PRESENT_EOS else: assert isinstance(scores[bidx][i], float) scores[bidx][i] = str(scores[bidx][i]) predictions = [' '.join(item) for item in predictions] scores = [' '.join(item) for item in scores] present_kps = [] absent_kps = [] present_kp_scores = [] absent_kp_scores = [] for bidx in range(ex['batch_size']): keyphrases = predictions[bidx].split(constants.PRESENT_EOS) kp_scores = scores[bidx].split(constants.PRESENT_EOS) pkps = (' %s ' % constants.KP_SEP).join(keyphrases[:-1]) pkp_scores = (' %s ' % constants.KP_SEP).join(kp_scores[:-1]) akps = keyphrases[-1] akp_scores = kp_scores[-1] pre_kps = [] pre_kp_scores = [] for pkp, pkp_s in zip(pkps.split(constants.KP_SEP), pkp_scores.split(constants.KP_SEP)): pkp = pkp.strip() if pkp: pre_kps.append(convert_text_to_string(pkp)) t_scores = [float(i) for i in pkp_s.strip().split()] _score = np.prod(t_scores) / len(t_scores) pre_kp_scores.append(_score) present_kps.append(pre_kps) present_kp_scores.append(pre_kp_scores) abs_kps = [] abs_kp_scores = [] for akp, akp_s in zip(akps.split(constants.KP_SEP), akp_scores.split(constants.KP_SEP)): akp = akp.strip() if akp: abs_kps.append(convert_text_to_string(akp)) t_scores = [float(i) for i in akp_s.strip().split()] _score 
= np.prod(t_scores) / len(t_scores) abs_kp_scores.append(_score) absent_kps.append(abs_kps) absent_kp_scores.append(abs_kp_scores) return { 'present_kps': present_kps, 'absent_kps': absent_kps, 'present_kp_scores': present_kp_scores, 'absent_kp_scores': absent_kp_scores } # -------------------------------------------------------------------------- # Saving and loading # -------------------------------------------------------------------------- def save(self, filename): network = self.network.module if hasattr(self.network, "module") \ else self.network state_dict = copy.copy(network.state_dict()) params = { 'state_dict': state_dict, 'word_dict': self.word_dict, 'args': self.args, } try: torch.save(params, filename) except BaseException: logger.warning('WARN: Saving failed... continuing anyway.') def checkpoint(self, filename, epoch): network = self.network.module if hasattr(self.network, "module") \ else self.network params = { 'state_dict': network.state_dict(), 'word_dict': self.word_dict, 'args': self.args, 'epoch': epoch, 'updates': self.updates, 'optim_dict': self.optimizer.state_dict(), 'sched_dict': self.scheduler.state_dict(), } try: torch.save(params, filename) except BaseException: logger.warning('WARN: Saving failed... continuing anyway.') @staticmethod def load(filename, new_args=None): logger.info('Loading model %s' % filename) saved_params = torch.load( filename, map_location=lambda storage, loc: storage ) word_dict = saved_params['word_dict'] state_dict = saved_params['state_dict'] args = saved_params['args'] if new_args: args = override_model_args(args, new_args) return Seq2seqKpGen(args, word_dict, state_dict) @staticmethod def load_checkpoint(filename): logger.info('Loading model %s' % filename) saved_params = torch.load( filename, map_location=lambda storage, loc: storage ) word_dict = saved_params['word_dict'] state_dict = saved_params['state_dict'] epoch = saved_params['epoch'] updates = saved_params['updates'] optim_dict = saved_params['optim_dict'] sched_dict = saved_params['sched_dict'] args = saved_params['args'] model = Seq2seqKpGen(args, word_dict, state_dict) model.updates = updates model.init_optimizer(optim_dict, sched_dict) return model, epoch # -------------------------------------------------------------------------- # Runtime # -------------------------------------------------------------------------- def to(self, device): self.network = self.network.to(device) def parallelize(self): self.network = torch.nn.DataParallel(self.network)
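# A minimal sketch of the weight-decay grouping used in init_optimizer above: biases and
# layer-norm weights go into a zero-decay group, everything else gets the configured
# decay. The toy transformer layer is illustrative; its layer norms are named
# norm1/norm2, so those patterns are added alongside the Hugging Face-style
# "LayerNorm.weight" used in the original.
import torch.nn as nn
from torch.optim import AdamW

network = nn.TransformerEncoderLayer(d_model=64, nhead=4)

no_decay = ["bias", "LayerNorm.weight", "norm1.weight", "norm2.weight"]
grouped_parameters = [
    {
        "params": [p for n, p in network.named_parameters()
                   if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.01,
    },
    {
        "params": [p for n, p in network.named_parameters()
                   if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(grouped_parameters, lr=1e-4)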
class Distiller: def __init__(self, params, dataloader, student, teacher, device): # Initializing Distiller self.params = params self.dump_path = params["dump_path"] self.student = student self.teacher = teacher self.device = device self.dataloader = dataloader self.temperature = params["temperature"] assert self.temperature > 0.0 self.alpha_ce = params["alpha_ce"] self.alpha_mlm = params["alpha_mlm"] self.alpha_mse = params["alpha_mse"] self.alpha_cos = params["alpha_cos"] self.mlm_mask_prop = params["mlm_mask_prop"] assert 0.0 <= self.mlm_mask_prop <= 1.0 self.epoch = 0 self.n_iter = 0 self.n_total_iter = 0 self.n_sequences_epoch = 0 self.total_loss_epoch = 0 self.last_loss = 0 self.last_loss_ce = 0 self.last_loss_mlm = 0 if self.alpha_mse > 0.0: self.last_loss_mse = 0 if self.alpha_cos > 0.0: self.last_loss_cos = 0 self.last_log = 0 self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean") self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100) if self.alpha_mse > 0.0: self.mse_loss_fct = nn.MSELoss(reduction="sum") if self.alpha_cos > 0.0: self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean") # Initializing model optimizer assert params["gradient_accumulation_steps"] >= 1 self.num_steps_epoch = len(self.dataloader) num_train_optimization_steps = ( int(self.num_steps_epoch / params["gradient_accumulation_steps"] * params["n_epoch"]) + 1) no_decay = ["bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { "params": [ p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad ], "weight_decay": params["weight_decay"], }, { "params": [ p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad ], "weight_decay": 0.0, }, ] self.optimizer = AdamW( optimizer_grouped_parameters, lr=params["learning_rate"], eps=params["adam_epsilon"], betas=(0.9, 0.98), ) warmup_steps = math.ceil(num_train_optimization_steps * params["warmup_prop"]) self.scheduler = get_linear_schedule_with_warmup( self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps, ) def train(self): """ The real training loop. 
""" self.student.train() self.teacher.eval() for _ in range(self.params["n_epoch"]): iter_bar = tqdm(self.dataloader, desc="-Iter") for batch in iter_bar: # batch = tuple(t.to(device) for t in batch) b_input_ids = batch["input_ids"].to(self.device) b_labels = batch["labels"].to(self.device) b_bool_attn_mask = batch["input_ids"] != 0 b_bool_attn_mask.to(self.device) self.step( input_ids=b_input_ids, attention_mask=b_bool_attn_mask, lm_labels=b_labels, ) iter_bar.update() iter_bar.close() self.end_epoch() self.save_checkpoint(checkpoint_name="pytorch_model.bin") print("Training is finished") def step(self, input_ids, attention_mask, lm_labels): s_output = self.student( input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, ) # (bs, seq_length, voc_size) s_logits, s_hidden_states = s_output["logits"], s_output[ "hidden_states"] with torch.no_grad(): t_output = self.teacher( input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, ) t_logits, t_hidden_states = t_output["logits"], t_output[ "hidden_states"] assert s_logits.size() == t_logits.size() mask = ((lm_labels > -1).unsqueeze(-1).expand_as(s_logits) ) # (bs, seq_length, voc_size) # or mask = attention_mask.unsqueeze(-1).expand_as(s_logits) # (bs, seq_length, voc_size) s_logits_slct = torch.masked_select( s_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask s_logits_slct = s_logits_slct.view(-1, s_logits.size( -1)) # (bs * seq_length, voc_size) modulo the 1s in mask t_logits_slct = torch.masked_select( t_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask t_logits_slct = t_logits_slct.view(-1, s_logits.size( -1)) # (bs * seq_length, voc_size) modulo the 1s in mask assert t_logits_slct.size() == s_logits_slct.size() loss_ce = (self.ce_loss_fct( F.log_softmax(s_logits_slct / self.temperature, dim=-1), F.softmax(t_logits_slct / self.temperature, dim=-1), ) * (self.temperature)**2) loss = self.alpha_ce * loss_ce if self.alpha_mlm > 0.0: loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1)) loss += self.alpha_mlm * loss_mlm if self.alpha_mse > 0.0: loss_mse = self.mse_loss_fct( s_logits_slct, t_logits_slct) / s_logits_slct.size( 0) # Reproducing batchmean reduction loss += self.alpha_mse * loss_mse if self.alpha_cos > 0.0: s_hidden_states = s_hidden_states[-1] # (bs, seq_length, dim) t_hidden_states = t_hidden_states[-1] # (bs, seq_length, dim) mask = attention_mask.unsqueeze(-1).expand_as( s_hidden_states) # (bs, seq_length, dim) assert s_hidden_states.size() == t_hidden_states.size() dim = s_hidden_states.size(-1) s_hidden_states_slct = torch.masked_select( s_hidden_states, mask) # (bs * seq_length * dim) s_hidden_states_slct = s_hidden_states_slct.view( -1, dim) # (bs * seq_length, dim) t_hidden_states_slct = torch.masked_select( t_hidden_states, mask) # (bs * seq_length * dim) t_hidden_states_slct = t_hidden_states_slct.view( -1, dim) # (bs * seq_length, dim) target = s_hidden_states_slct.new( s_hidden_states_slct.size(0)).fill_(1) # (bs * seq_length,) loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target) loss += self.alpha_cos * loss_cos self.total_loss_epoch += loss.item() self.last_loss = loss.item() self.last_loss_ce = loss_ce.item() if self.alpha_mlm > 0.0: self.last_loss_mlm = loss_mlm.item() if self.alpha_mse > 0.0: self.last_loss_mse = loss_mse.item() if self.alpha_cos > 0.0: self.last_loss_cos = loss_cos.item() self.optimize(loss) self.n_sequences_epoch += input_ids.size(0) def optimize(self, 
loss): """ Normalization on the loss (gradient accumulation or distributed training), followed by backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation). Also update the metrics for tensorboard. """ # Check for NaN if (loss != loss).data.any(): print("NaN detected") exit() if self.params["gradient_accumulation_steps"] > 1: loss = loss / self.params["gradient_accumulation_steps"] loss.backward() self.iter() if self.n_iter % self.params["gradient_accumulation_steps"] == 0: torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params["max_grad_norm"]) self.optimizer.step() self.optimizer.zero_grad() self.scheduler.step() def iter(self): """ Update global counts, write to tensorboard and save checkpoint. """ self.n_iter += 1 self.n_total_iter += 1 def end_epoch(self): """ Finally arrived at the end of epoch (full pass on dataset). Do some tensorboard logging and checkpoint saving. """ self.save_checkpoint(checkpoint_name=f"model_epoch_{self.epoch}.pth") self.epoch += 1 self.n_sequences_epoch = 0 self.n_iter = 0 self.total_loss_epoch = 0 def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"): """ Save the current state. Only by the master process. """ mdl_to_save = (self.student.module if hasattr(self.student, "module") else self.student) mdl_to_save.config.save_pretrained(self.dump_path) state_dict = mdl_to_save.state_dict() torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
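# A small sketch of the accumulation logic in Distiller.optimize above: the loss is
# divided by the number of accumulation steps, gradients are clipped, and the optimizer
# and scheduler only advance every `accum_steps` micro-batches. The model, data, and
# hyper-parameters are placeholders.
import torch
import torch.nn as nn
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

accum_steps, max_grad_norm, total_iters = 4, 5.0, 100
model = nn.Linear(32, 32)
optimizer = AdamW(model.parameters(), lr=5e-4, eps=1e-6, betas=(0.9, 0.98))
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=5, num_training_steps=total_iters // accum_steps)

for it in range(1, total_iters + 1):
    x = torch.randn(8, 32)
    loss = nn.functional.mse_loss(model(x), torch.randn(8, 32))
    (loss / accum_steps).backward()  # normalise the loss for accumulation
    if it % accum_steps == 0:
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()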
class TD3Agent(AgentBase): """ Twin Delayed Deep Deterministic (TD3) Policy Gradient. In short, it's a slightly modified/improved version of the DDPG. Compared to the DDPG in this package, which uses Guassian noise, this TD3 uses Ornstein–Uhlenbeck process as the noise. """ name = "TD3" def __init__(self, state_size: int, action_size: int, noise_scale: float = 0.2, noise_sigma: float = 0.1, **kwargs): """ Parameters: state_size (int): Number of input dimensions. action_size (int): Number of output dimensions noise_scale (float): Added noise amplitude. Default: 0.2. noise_sigma (float): Added noise variance. Default: 0.1. Keyword parameters: hidden_layers (tuple of ints): Tuple defining hidden dimensions in fully connected nets. Default: (128, 128). actor_lr (float): Learning rate for the actor (policy). Default: 0.003. critic_lr (float): Learning rate for the critic (value function). Default: 0.003. gamma (float): Discount value. Default: 0.99. tau (float): Soft-copy factor. Default: 0.02. actor_hidden_layers (tuple of ints): Shape of network for actor. Default: `hideen_layers`. critic_hidden_layers (tuple of ints): Shape of network for critic. Default: `hideen_layers`. max_grad_norm_actor (float) Maximum norm value for actor gradient. Default: 100. max_grad_norm_critic (float): Maximum norm value for critic gradient. Default: 100. batch_size (int): Number of samples used in learning. Default: 64. buffer_size (int): Maximum number of samples to store. Default: 1e6. warm_up (int): Number of samples to observe before starting any learning step. Default: 0. update_freq (int): Number of steps between each learning step. Default 1. number_updates (int): How many times to use learning step in the learning phase. Default: 1. action_min (float): Minimum returned action value. Default: -1. action_max (float): Maximum returned action value. Default: 1. action_scale (float): Multipler value for action. Default: 1. """ super().__init__(**kwargs) self.device = self._register_param( kwargs, "device", DEVICE) # Default device is CUDA if available # Reason sequence initiation. self.state_size = state_size self.action_size = action_size hidden_layers = to_numbers_seq( self._register_param(kwargs, 'hidden_layers', (128, 128))) self.actor = ActorBody(state_size, action_size, hidden_layers=hidden_layers).to(self.device) self.critic = DoubleCritic(state_size, action_size, CriticBody, hidden_layers=hidden_layers).to(self.device) self.target_actor = ActorBody(state_size, action_size, hidden_layers=hidden_layers).to( self.device) self.target_critic = DoubleCritic(state_size, action_size, CriticBody, hidden_layers=hidden_layers).to( self.device) # Noise sequence initiation # self.noise = GaussianNoise(shape=(action_size,), mu=1e-8, sigma=noise_sigma, scale=noise_scale, device=device) self.noise = OUProcess(shape=action_size, scale=noise_scale, sigma=noise_sigma, device=self.device) # Target sequence initiation hard_update(self.target_actor, self.actor) hard_update(self.target_critic, self.critic) # Optimization sequence initiation. 
actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-3)) critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-3)) self.actor_optimizer = AdamW(self.actor.parameters(), lr=actor_lr) self.critic_optimizer = AdamW(self.critic.parameters(), lr=critic_lr) self.max_grad_norm_actor: float = float( kwargs.get("max_grad_norm_actor", 100)) self.max_grad_norm_critic: float = float( kwargs.get("max_grad_norm_critic", 100)) self.action_min = float(self._register_param(kwargs, 'action_min', -1.)) self.action_max = float(self._register_param(kwargs, 'action_max', 1.)) self.action_scale = float( self._register_param(kwargs, 'action_scale', 1.)) self.gamma = float(self._register_param(kwargs, 'gamma', 0.99)) self.tau = float(self._register_param(kwargs, 'tau', 0.02)) self.batch_size = int(self._register_param(kwargs, 'batch_size', 64)) self.buffer_size = int( self._register_param(kwargs, 'buffer_size', int(1e6))) self.buffer = ReplayBuffer(self.batch_size, self.buffer_size) self.warm_up = int(self._register_param(kwargs, 'warm_up', 0)) self.update_freq = int(self._register_param(kwargs, 'update_freq', 1)) self.update_policy_freq = int( self._register_param(kwargs, 'update_policy_freq', 1)) self.number_updates = int( self._register_param(kwargs, 'number_updates', 1)) self.noise_reset_freq = int( self._register_param(kwargs, 'noise_reset_freq', 10000)) # Breath, my child. self.reset_agent() self.iteration = 0 self._loss_actor = 0. self._loss_critic = 0. @property def loss(self) -> Dict[str, float]: return {'actor': self._loss_actor, 'critic': self._loss_critic} @loss.setter def loss(self, value): if isinstance(value, dict): self._loss_actor = value['actor'] self._loss_critic = value['critic'] else: self._loss_actor = value self._loss_critic = value def reset_agent(self) -> None: self.actor.reset_parameters() self.critic.reset_parameters() self.target_actor.reset_parameters() self.target_critic.reset_parameters() def act(self, state, epsilon: float = 0.0, training_mode=True) -> List[float]: """ Agent acting on observations. When the training_mode is True (default) a noise is added to each action. """ # Epsilon greedy if self._rng.random() < epsilon: rnd_actions = torch.rand(self.action_size) * ( self.action_max - self.action_min) - self.action_min return rnd_actions.tolist() with torch.no_grad(): state = to_tensor(state).float().to(self.device) action = self.actor(state) if training_mode: action += self.noise.sample() return (self.action_scale * torch.clamp(action, self.action_min, self.action_max)).tolist() def target_act(self, staten, noise: float = 0.0): with torch.no_grad(): staten = to_tensor(staten).float().to(self.device) action = self.target_actor(staten) + noise * self.noise.sample() return torch.clamp(action, self.action_min, self.action_max).cpu().numpy().astype( np.float32) def step(self, state, action, reward, next_state, done): self.iteration += 1 self.buffer.add(state=state, action=action, reward=reward, next_state=next_state, done=done) if (self.iteration % self.noise_reset_freq) == 0: self.noise.reset_states() if self.iteration < self.warm_up: return if len(self.buffer) <= self.batch_size: return if not (self.iteration % self.update_freq) or not ( self.iteration % self.update_policy_freq): for _ in range(self.number_updates): # Note: Inside this there's a delayed policy update. # Every `update_policy_freq` it will learn `number_updates` times. 
                self.learn(self.buffer.sample())

    def learn(self, experiences):
        """Update critics and actors"""
        rewards = to_tensor(experiences['reward']).float().to(self.device).unsqueeze(1)
        dones = to_tensor(experiences['done']).type(torch.int).to(self.device).unsqueeze(1)
        states = to_tensor(experiences['state']).float().to(self.device)
        actions = to_tensor(experiences['action']).to(self.device)
        next_states = to_tensor(experiences['next_state']).float().to(self.device)

        if (self.iteration % self.update_freq) == 0:
            self._update_value_function(states, actions, rewards, next_states, dones)

        if (self.iteration % self.update_policy_freq) == 0:
            self._update_policy(states)
            soft_update(self.target_actor, self.actor, self.tau)
            soft_update(self.target_critic, self.critic, self.tau)

    def _update_value_function(self, states, actions, rewards, next_states, dones):
        # critic loss
        next_actions = self.target_actor.act(next_states)
        Q_target_next = torch.min(*self.target_critic.act(next_states, next_actions))
        Q_target = rewards + (self.gamma * Q_target_next * (1 - dones))
        Q1_expected, Q2_expected = self.critic(states, actions)
        loss_critic = mse_loss(Q1_expected, Q_target) + mse_loss(Q2_expected, Q_target)

        # Minimize the loss
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm_critic)
        self.critic_optimizer.step()
        self._loss_critic = float(loss_critic.item())

    def _update_policy(self, states):
        # Compute actor loss
        pred_actions = self.actor(states)
        loss_actor = -self.critic(states, pred_actions)[0].mean()
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm_actor)
        self.actor_optimizer.step()
        self._loss_actor = loss_actor.item()

    def state_dict(self) -> Dict[str, dict]:
        """Describes agent's networks.

        Returns:
            state: (dict) Provides actors and critics states.
        """
        return {
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic.state_dict(),
        }

    def log_metrics(self, data_logger: DataLogger, step: int, full_log: bool = False):
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)

    def get_state(self):
        return dict(
            actor=self.actor.state_dict(),
            target_actor=self.target_actor.state_dict(),
            critic=self.critic.state_dict(),
            target_critic=self.target_critic.state_dict(),
            config=self._config,
        )

    def save_state(self, path: str):
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, path: str):
        agent_state = torch.load(path)
        self._config = agent_state.get('config', {})
        self.__dict__.update(**self._config)
        self.actor.load_state_dict(agent_state['actor'])
        self.critic.load_state_dict(agent_state['critic'])
        self.target_actor.load_state_dict(agent_state['target_actor'])
        self.target_critic.load_state_dict(agent_state['target_critic'])
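# A compact sketch of the optimization layout used by TD3Agent above: one AdamW per
# network (actor and critic), gradient-norm clipping before each step, and a Polyak soft
# update of the target networks. The tiny MLPs, the simplified one-step bootstrap target,
# and the random transitions are illustrative only.
import torch
import torch.nn as nn
from torch.optim import AdamW

def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    # target <- tau * source + (1 - tau) * target
    for t, s in zip(target.parameters(), source.parameters()):
        t.data.copy_(tau * s.data + (1.0 - tau) * t.data)

actor = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2), nn.Tanh())
critic = nn.Sequential(nn.Linear(4 + 2, 64), nn.ReLU(), nn.Linear(64, 1))
target_actor = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2), nn.Tanh())
target_critic = nn.Sequential(nn.Linear(4 + 2, 64), nn.ReLU(), nn.Linear(64, 1))
target_actor.load_state_dict(actor.state_dict())
target_critic.load_state_dict(critic.state_dict())

actor_optimizer = AdamW(actor.parameters(), lr=3e-3)
critic_optimizer = AdamW(critic.parameters(), lr=3e-3)

states = torch.randn(64, 4)
actions, rewards = torch.randn(64, 2), torch.randn(64, 1)

# Critic update against a (here simplified) bootstrap target.
with torch.no_grad():
    q_target = rewards + 0.99 * target_critic(
        torch.cat([states, target_actor(states)], dim=1))
loss_critic = nn.functional.mse_loss(critic(torch.cat([states, actions], dim=1)), q_target)
critic_optimizer.zero_grad()
loss_critic.backward()
nn.utils.clip_grad_norm_(critic.parameters(), 100)
critic_optimizer.step()

# Policy update followed by soft target updates.
loss_actor = -critic(torch.cat([states, actor(states)], dim=1)).mean()
actor_optimizer.zero_grad()
loss_actor.backward()
nn.utils.clip_grad_norm_(actor.parameters(), 100)
actor_optimizer.step()
soft_update(target_actor, actor, tau=0.02)
soft_update(target_critic, critic, tau=0.02)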
def train(args): # torch.multiprocessing.set_sharing_strategy('file_system') # too many barriers / one node data parallel and multiple node DDP os.environ['MASTER_ADDR'] = args["master_addr"] os.environ['MASTER_PORT'] = args["master_port"] os.environ['TOKENIZERS_PARALLELISM'] = "true" torch.backends.cudnn.benchmark = True rank = args["nr"] gpus = args["gpus_per_node"] if args["cpu"]: assert args["world_size"] == 1 device = torch.device("cpu") barrier = get_barrier(False) else: dist.init_process_group(args["dist_backend"], rank=rank, world_size=args["world_size"]) device = torch.device('cuda:0') # Unique only on individual node. torch.cuda.set_device(device) barrier = get_barrier(True) set_seeds(args["seed"]) mconf = model_config.to_dict() config = dict(md_config=md_config, sm_config=sm_config)[mconf.pop("model_size")] tokenizer = get_tokenizer(mconf.pop("tokenizer_name")) config.vocab_size = len(tokenizer) + 22 config.tokenizer_length = 1024 config.tokenizer_length = config.tokenizer_length - config.num_highway_cls_tokens config.max_position_embeddings = config.max_position_embeddings + config.num_highway_cls_tokens collate_fn = get_collate_fn(config.num_highway_cls_tokens, tokenizer.pad_token_id) model = FastFormerForFusedELECTRAPretraining(config, tokenizer=tokenizer, **mconf).to(device) print("Trainable Params = %s" % (numel(model) / 1_000_000)) if args["pretrained_model"] is not None: model.load_state_dict(torch.load(args["pretrained_model"], map_location={'cuda:%d' % 0: 'cuda:%d' % 0})) model.data_parallel = True # Take model to local rank if args["cpu"]: ddp_model = model else: if torch.cuda.device_count() > 1: model = nn.DataParallel(model) ddp_model = DDP(model, device_ids=[0], find_unused_parameters=True) all_params = list(filter(lambda p: p.requires_grad, ddp_model.parameters())) optc = optimizer_config.to_dict() optimizer = AdamW(all_params, lr=optc["lr"], eps=optc["eps"], weight_decay=optc["weight_decay"], betas=(optc["beta_1"], optc["beta_2"])) optimizer.zero_grad() scaler = GradScaler() model_save_dir = args["model_save_dir"] model_save_name = args["model_save_name"] if rank == 0: if not os.path.exists(model_save_dir): os.makedirs(model_save_dir) assert os.path.exists(model_save_dir) barrier() print("Optimizer Created for Rank = %s" % rank) shuffle_dataset = args["shuffle_dataset"] sampling_fraction = optc["sampling_fraction"] if not args["validate_only"] and not args["test_only"]: train_loader = build_dataloader(args["train_dataset"], shuffle_dataset, sampling_fraction, config, collate_fn, tokenizer, world_size=args["world_size"], num_workers=args["num_workers"]) print("Data Loaded for Rank = %s" % rank) validate_every_steps = args["validate_every_steps"] log_every_steps = args["log_every_steps"] save_every_steps = args["save_every_steps"] scheduler = optimization.get_constant_schedule_with_warmup(optimizer, optc["warmup_steps"]) gradient_clipping = optc["gradient_clipping"] _ = model.train() barrier() start_time = time.time() batch_times = [] model_times = [] full_times = [] print("Start Training for Rank = %s" % rank) for step, batch in enumerate(train_loader): model.zero_grad() optimizer.zero_grad() if step == 0: print("First Batch Training for Rank = %s" % rank) # if step <= 39: # continue gen_batch_time = time.time() - start_time batch_times.append(gen_batch_time) if (step + 1) % save_every_steps == 0: if rank == 0: torch.save(ddp_model.state_dict(), os.path.join(model_save_dir, model_save_name)) barrier() if (step + 1) % validate_every_steps == 0: if rank == 0: 
val_results = LargeValidator(args["validation_dataset"], ddp_model, config, device, tokenizer)() print("Rank = %s, steps = %s, Val = %s" % (rank, step, val_results)) barrier() record_accuracy = False if (step + 1) % log_every_steps == 0: record_accuracy = True batch["record_accuracy"] = record_accuracy labels = batch["label_mlm_input_ids"] if "label_mlm_input_ids" in batch else batch["input_ids"] labels = labels.to(device) model_start_time = time.time() if args["cpu"]: output = ddp_model(**batch, labels=labels) output = {key: [item[key] for item in output] for key in list(functools.reduce( lambda x, y: x.union(y), (set(dicts.keys()) for dicts in output) )) } output = {k: torch.mean(v) for k, v in output.items()} loss = output["loss"] loss_dict = output["loss_dict"] loss.backward() torch.nn.utils.clip_grad_norm_(all_params, gradient_clipping) optimizer.step() scheduler.step() optimizer.zero_grad() else: with autocast(): output = ddp_model(**batch, labels=labels) output = {key: [item[key] for item in output] for key in list(functools.reduce( lambda x, y: x.union(y), (set(dicts.keys()) for dicts in output) )) } output = {k: torch.mean(v) for k, v in output.items()} loss = output["loss"] loss_dict = output["loss_dict"] scaler.scale(loss).backward() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(all_params, gradient_clipping) scaler.step(optimizer) scaler.update() scheduler.step() optimizer.zero_grad() model_end_time = time.time() - model_start_time model_times.append(model_end_time) full_time = time.time() - start_time full_times.append(full_time) start_time = time.time() if (step + 1) % log_every_steps == 0: print("Rank = %s, steps = %s, batch_size = %s, Loss = %s, Accuracy = %s" % (rank, step, batch["input_ids"].size(), loss_dict, output["accuracy_hist"])) print("Batch time = %s, Model Time = %s, Full time = %s" % (np.mean(batch_times), np.mean(model_times), np.mean(full_times))) batch_times = [] model_times = [] full_times = [] clean_memory() barrier() # Take inputs to local_rank # TODO: validate on multigpu, sort the val datasets alphabetically and let the gpu with rank == dataset rank in sort pick up the dataset. GPUs with rank > len(datasetDict) stay idle. # TODO: select one dataset and make full batch from it, this way rebalancing can be easy. # TODO: dataset rebalancing. # TODO: save model only in local_rank == 0 process # TODO: Check if all initialised model weights are same?? # I've been tracking an ema of sample training loss during training and using that to guide weighted data sampling (rather than the typical uniform sampling). Seems to help with a variety of real world datasets where the bulk of the data is often very similar and easy to learn but certain subpopulations are much more challenging. pass
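# A minimal sketch of the mixed-precision branch in the training loop above: forward under
# autocast, scale the loss, unscale before gradient clipping, then let the scaler drive the
# optimizer step. Requires a CUDA device; the model and batch are placeholders.
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW

device = torch.device("cuda")
model = nn.Linear(256, 256).to(device)
optimizer = AdamW(model.parameters(), lr=1e-4, eps=1e-6, weight_decay=0.01)
scaler = GradScaler()
gradient_clipping = 1.0

for _ in range(10):
    batch = torch.randn(32, 256, device=device)
    optimizer.zero_grad()
    with autocast():
        loss = model(batch).pow(2).mean()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                      # so clipping sees the true gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
    scaler.step(optimizer)                          # skipped automatically on inf/nan gradients
    scaler.update()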
def train(args, train_dataset, model):
    """Train the model on `steps` batches"""
    logger.debug('start')
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    print('train_batch_size %d' % args.train_batch_size)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
    # Prepare optimizer and schedule (linear warmup and decay)
    # Parameters that should not receive weight decay
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    bert_param_optimizer = list(model.bert.named_parameters())
    crf_param_optimizer = list(model.crf.named_parameters())
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in bert_param_optimizer if not any(nd in n for nd in no_decay)],
            'weight_decay': args.weight_decay,
            'lr': args.bert_lr
        },
        {
            'params': [p for n, p in bert_param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
            'lr': args.bert_lr
        },
        {
            'params': [p for n, p in crf_param_optimizer if not any(nd in n for nd in no_decay)],
            'weight_decay': args.weight_decay,
            'lr': args.crf_lr
        },
        {
            'params': [p for n, p in crf_param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
            'lr': args.crf_lr
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters)
    # args.warmup_steps = int(t_total * args.warmup_proportion)
    # scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
    #                                             num_training_steps=t_total)
    # scheduler.step()

    # Train!
    logger.info("***** Running training *****")
    global_step = 0
    for epoch in range(int(args.num_train_epochs)):
        for step, batch_data in enumerate(train_dataloader):
            # set model to training mode
            model.train()
            batch_data = tuple(t.to(args.device) for t in batch_data)
            batch_input_ids, batch_input_mask, batch_segment_ids, batch_label_ids = batch_data
            optimizer.zero_grad()
            outputs = model(input_ids=batch_input_ids,
                            attention_mask=batch_input_mask,
                            token_type_ids=batch_segment_ids,
                            labels=batch_label_ids)
            loss, scores = outputs[:2]
            loss.backward()
            optimizer.step()
            if step % 5 == 0:
                print('epoch: {} | step: {} | loss: {}'.format(epoch, step, loss.item()))
            global_step += 1
    torch.save(model.state_dict(), args.modelfile_finetuned)
    return global_step
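# A short sketch of the per-group learning-rate idea used above: encoder parameters get a
# small lr, the tagging head (a stand-in for the CRF layer) a larger one, and each module
# is split further into decay / no-decay groups. A group-level "lr" overrides the
# optimizer default. Modules and rates are hypothetical.
import torch.nn as nn
from torch.optim import AdamW

encoder = nn.LSTM(input_size=64, hidden_size=64, batch_first=True)
head = nn.Linear(64, 9)  # stand-in for the CRF layer
no_decay = ["bias"]

groups = []
for module, lr in ((encoder, 3e-5), (head, 1e-3)):
    named = list(module.named_parameters())
    groups.append({"params": [p for n, p in named if not any(nd in n for nd in no_decay)],
                   "weight_decay": 0.01, "lr": lr})
    groups.append({"params": [p for n, p in named if any(nd in n for nd in no_decay)],
                   "weight_decay": 0.0, "lr": lr})

optimizer = AdamW(groups)  # groups without an explicit lr would fall back to the default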
def train(): """ Train the model using the parameters defined in the config file """ print('Initialising ...') cfg = TrainConfig() checkpoint_folder = 'checkpoints/{}/'.format(cfg.experiment_name) if not os.path.exists(checkpoint_folder): os.makedirs(checkpoint_folder) tb_folder = 'tb/{}/'.format(cfg.experiment_name) if not os.path.exists(tb_folder): os.makedirs(tb_folder) writer = SummaryWriter(logdir=tb_folder, flush_secs=30) model = ParrotModel().cuda().train() optimiser = AdamW(model.parameters(), lr=cfg.initial_lr, weight_decay=cfg.weight_decay) train_dataset = ParrotDataset(cfg.train_labels, cfg.mp3_folder) train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, num_workers=cfg.workers, collate_fn=parrot_collate_function, pin_memory=True) val_dataset = ParrotDataset(cfg.val_labels, cfg.mp3_folder) val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, num_workers=cfg.workers, collate_fn=parrot_collate_function, shuffle=False, pin_memory=True) epochs = cfg.epochs init_loss, step = 0., 0 avg_loss = AverageMeter() print('Starting training') for epoch in range(epochs): loader_length = len(train_loader) epoch_start = time.time() for batch_idx, batch in enumerate(train_loader): optimiser.zero_grad() # VRAM control by skipping long examples if batch['spectrograms'].shape[-1] > cfg.max_time: continue # inference target = batch['targets'].cuda() model_input = batch['spectrograms'].cuda() model_output = model(model_input) # loss input_lengths = batch['input_lengths'].cuda() target_lengths = batch['target_lengths'].cuda() loss = ctc_loss(model_output, target, input_lengths, target_lengths) loss.backward() if epoch == 0 and batch_idx == 0: init_loss = loss # logging elapsed = time.time() - epoch_start progress = batch_idx / loader_length est = datetime.timedelta( seconds=int(elapsed / progress)) if progress > 0.001 else '-' avg_loss.update(loss) suffix = '\tloss {:.4f}/{:.4f}\tETA [{}/{}]'.format( avg_loss.avg, init_loss, datetime.timedelta(seconds=int(elapsed)), est) printProgressBar(batch_idx, loader_length, suffix=suffix, prefix='Epoch [{}/{}]\tStep [{}/{}]'.format( epoch, epochs, batch_idx, loader_length)) writer.add_scalar('Steps/train_loss', loss, step) # saving the model if step % cfg.checkpoint_every == 0: test_name = '{}/test_epoch{}.mp3'.format( checkpoint_folder, epoch) test_mp3_file(cfg.test_mp3, model, test_name) checkpoint_name = '{}/epoch_{}.pth'.format( checkpoint_folder, epoch) torch.save( { 'model': model.state_dict(), 'epoch': epoch, 'batch_idx': loader_length, 'step': step, 'optimiser': optimiser.state_dict() }, checkpoint_name) # validating if step % cfg.val_every == 0: val(model, val_loader, writer, step) model = model.train() step += 1 optimiser.step() # end of epoch print('') writer.add_scalar('Epochs/train_loss', avg_loss.avg, epoch) avg_loss.reset() test_name = '{}/test_epoch{}.mp3'.format(checkpoint_folder, epoch) test_mp3_file(cfg.test_mp3, model, test_name) checkpoint_name = '{}/epoch_{}.pth'.format(checkpoint_folder, epoch) torch.save( { 'model': model.state_dict(), 'epoch': epoch, 'batch_idx': loader_length, 'step': step, 'optimiser': optimiser.state_dict() }, checkpoint_name) # finished training writer.close() print('Training finished :)')
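# A bare-bones sketch of the CTC setup the training loop above relies on: model output
# shaped (time, batch, classes) as log-probabilities, plus per-sample input and target
# lengths, optimized with AdamW. The shapes and the tiny recurrent model are placeholders,
# not the ParrotModel itself.
import torch
import torch.nn as nn
from torch.optim import AdamW

T, N, C, S = 50, 4, 28, 10          # time steps, batch, classes (incl. blank), max target length
model = nn.GRU(input_size=80, hidden_size=C, batch_first=False)
criterion = nn.CTCLoss(blank=0)
optimiser = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

spectrograms = torch.randn(T, N, 80)
targets = torch.randint(1, C, (N, S))
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.randint(1, S + 1, (N,), dtype=torch.long)

log_probs = model(spectrograms)[0].log_softmax(dim=-1)   # (T, N, C)
loss = criterion(log_probs, targets, input_lengths, target_lengths)

optimiser.zero_grad()
loss.backward()
optimiser.step()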
class Trainer : def __init__(self , eval_df , train_df, max_length , batch_size , n_class , name_model ) : self.model = sentiment_analysis(n_class ,name_model ) self.tokenizer = self.model.tokenizer self.eval_df = eval_df self.train_df = train_df self.max_length = max_length self.batch_size = batch_size self.eval_dataloader = create_data_loader(self.eval_df, self.tokenizer, max_len = self.max_length , batch_size = self.batch_size) self.train_dataloader = create_data_loader(self.train_df, self.tokenizer, max_len = self.max_length , batch_size = self.batch_size) self.optimizer = AdamW(self.model.parameters(), lr=2e-5) self.loss_fn = nn.CrossEntropyLoss() def train_epoch(self) : losses = [] train_correct_predictions = [] self.model.train() for data in self.train_dataloader : input_ids = data['input_ids'] attention_mask = data['attention_mask'] targets = data['targets'] outputs = self.model(input_ids = input_ids , attention_mask = attention_mask ) preds = torch.max(outputs , dim = -1)[1] #Calculate metrics loss = self.loss_fn(outputs , targets) losses.append(loss) for i in range(len(preds)) : if preds[i] == targets[i] : train_correct_predictions.append(1) else : train_correct_predictions.append(0) loss.backward() nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) self.optimizer.step() self.scheduler.step() self.optimizer.zero_grad() return sum(train_correct_predictions) / float(len(train_correct_predictions)) , sum(losses) / float(len(losses)) def eval_model(self) : self.model.eval() losses_eval = [] correct_predicitons = [] with torch.no_grad(): for d in self.eval_dataloader : input_ids = d['input_ids'] attention_mask = d['attention_mask'] targets = d['targets'] outputs = self.model( input_ids = input_ids , attention_mask = attention_mask) _ , preds = torch.max(outputs , dim = 1) loss = self.loss_fn(outputs , targets ) losses_eval.append(loss.item()) for i in range(len(preds)) : if preds[i] == targets[i] : correct_predicitons.append(1) else : correct_predicitons.append(0) return sum(correct_predicitons) / float(len(correct_predicitons)) , sum(losses_eval) / float(len(losses_eval)) def train (self , EPOCHS) : best_accuracy = 0 history = defaultdict(list) total_steps = len(self.train_dataloader) * EPOCHS print(total_steps) self.scheduler = get_linear_schedule_with_warmup(self.optimizer , num_warmup_steps=0, num_training_steps=total_steps) for number_epochs in range(EPOCHS) : train_acc, train_loss = self.train_epoch () print(f'Train loss {train_loss} accuracy {train_acc}') val_acc , val_loss = self.eval_model() print(f'eval loss {val_loss} accuracy {val_acc}') history['train_acc'].append(train_acc) history['train_loss'].append(train_loss) history['val_acc'].append(val_acc) history['val_loss'].append(val_loss) if val_acc > best_accuracy: self.model.save('best_model_state') best_accuracy = val_acc
# x_df = x_df.style.background_gradient(cmap='Greys', axis=None, subset=slice(0,10))
placeholders_[0][0].write(x_df)
y_df = pd.DataFrame(data=y.detach().numpy())
y_df = y_df.style.background_gradient(cmap='Greys', axis=None)
placeholders_[1][0].write(y_df)

output = net(x.flatten()).reshape((3, 4))
loss = criterion(output, y)
out_df = pd.DataFrame(data=output.detach().numpy())
out_df = out_df.style.background_gradient(cmap='Greys', axis=None)
placeholders_[2][0].write(out_df)
print(f'Loss: {loss.detach()}')

optimizer.zero_grad()
loss.backward()
optimizer.step()

params = net.parameters()
# print(list(enumerate(params)))
for i, param in enumerate(params):
    if i == 0:
        p_df = pd.DataFrame(data=param.reshape(-1, 3).detach().numpy())
        p_df = p_df.style.background_gradient(cmap='Greys', axis=None)
        placeholders[i][0].write(p_df)
    else:
        placeholders[i][0].write(param.detach().numpy())
# print(params)
# exit()
# placeholders[0][0].write()
def train(self, model, epochs_num=1, train_dataset=None, validation_dataset=None, data_collator=None, parent_information=None, lr=0.01, batch_size=64, weight_decay=0.01, betas=(0.9, 0.999), evaluate_steps=40, has_parent=True, verbose=False): ''' Train the model given with the dataset provided. Will run evaluation on the validation set every `evaluate_steps` training steps, and at the end of each epoch. Args: model: instantiated model to train epochs_num: Number of epochs to train train_dataset: Train dataset validation_dataset: Validation dataset data_collator: A data collator function that when called will collate the data, passed to Dataloader parent_information: lr: Learning rate to use in the Opimizer batch_size: Batch size to use weight_decay: Optimizer wieght decay betas: Betas used in the Optimizer evaluate_steps: How many training steps verbose: If true the training loss and addition f1 scores will be printed every step Returns: f1: double, the resulting mean f1 score of all the labels (it will be a number between 0 and 1) precision: double, the resulting mean precision of all the labels (it will be a number between 0 and 1) recall: ''' self.model = model # Prints additional loss and metrics information during training if set to true self.verbose = verbose # Set timers start = time.time() remaining_time = 0 # Get dataloader self.data_collator = data_collator train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=self.data_collator, shuffle=True) # Default optimizer optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, betas=betas) mb = master_bar(range(epochs_num)) pb = progress_bar(train_dataloader, parent=mb) for epoch in mb: for i_batch, sample_batched in enumerate(pb): self.model.train() # Get input x = sample_batched[0].to(self.device) #if i_batch == 0: #print() #print(x.size()) # Get targets (labels) target = sample_batched[1].float().to(self.device) if has_parent and (len(sample_batched) == 3): parent_labels = sample_batched[2].float().to(self.device) if self.device == 'cuda': self.model.cuda(0) x = x.cuda(0) parent_labels = parent_labels.cuda(0) target = target.cuda(0) else: self.model.cpu() x = x.cpu() parent_labels = parent_labels.cpu() target = target.cpu() # Pass input to model output = self.model(x, parent_labels) else: if self.device == 'cuda': self.model.cuda(0) x = x.cuda(0) target = target.cuda(0) else: self.model.cpu() x = x.cpu() target = target.cpu() # Pass input to model output = self.model(x) # Loss train_loss = self.criterion(output, target) if self.verbose: print(f'train_loss: {train_loss}') # Do backward, do step and zero gradients train_loss.backward() optimizer.step() model.zero_grad() optimizer.zero_grad() # Evaluate if (i_batch > 0) and (i_batch % evaluate_steps) == 0: #print('\nevaluating...') _ = self.evaluate(self.model, validation_dataset) self.train_losses.append(train_loss.item()) # Run evaluation at the end of each epoch and return validation outputs #print('\nEnd of epoch evaluation results:') validation_outputs = self.evaluate(model, validation_dataset) y_hat_validation, validation_labels_child, validation_labels_parent = validation_outputs # Print out progress stats end = time.time() remaining_time = remaining_time * 0.90 + ( (end - start) * (epochs_num - epoch + 1) / (epoch + 1)) * 0.1 remaining_time_corrected = remaining_time / (1 - (0.9**(epoch + 1))) epoch_str = "last epoch finished: " + str(epoch + 1) progress_str = "progress: " + str( (epoch + 1) * 100 / epochs_num) + "%" time_str = "time: " + 
str(remaining_time_corrected / 60) + " mins" sys.stdout.write("\r" + epoch_str + " -- " + progress_str + " -- " + time_str) sys.stdout.flush() self.epochs.append(epoch) print("\n" + "Training completed. Total training time: " + str(round((end - start) / 60, 2)) + " mins") return (y_hat_validation, validation_labels_child, validation_labels_parent, self.train_losses, self.validation_losses, self.f1_scores_validations, self.precisions_validations, self.recalls_validations)