def train(self, train, dev, test,
          buckets=32,
          batch_size=5000,
          lr=2e-3,
          mu=.9,
          nu=.9,
          epsilon=1e-12,
          clip=5.0,
          decay=.75,
          decay_steps=5000,
          epochs=5000,
          patience=100,
          weight_decay=0,
          verbose=True,
          **kwargs):
    """
    Train the model, keep the best checkpoint by dev metric and report its test score.

    Every call argument is folded into ``self.args`` through
    ``self.args.update(locals())``; from then on the settings are read back
    from the returned ``args`` object.

    Args:
        train, dev, test:
            Data sources handed to ``Dataset`` (as ``args.train`` / ``args.dev`` /
            ``args.test``). Presumably file paths -- confirm against ``Dataset``.
        buckets:
            Second argument of ``Dataset.build`` for all three sets.
        batch_size:
            First argument of ``Dataset.build``. Floor-divided by
            ``dist.get_world_size()`` when ``dist`` is initialized.
        lr, mu, nu, epsilon, weight_decay:
            Passed to ``Adam`` as the learning rate, the ``(mu, nu)`` pair, the
            epsilon term and ``weight_decay`` (assuming ``Adam`` is
            ``torch.optim.Adam`` -- confirm against the file's imports).
        clip:
            Not read in this method; it only reaches ``args``
            (presumably consumed by ``_train`` -- confirm).
        decay, decay_steps:
            ``ExponentialLR`` receives ``decay ** (1 / decay_steps)`` as its factor.
        epochs:
            Upper bound on the number of epochs.
        patience:
            Training stops as soon as ``epoch - best_e`` reaches this value.
        verbose:
            Forwarded to ``init_logger``.
        kwargs:
            Extra settings, captured by ``locals()`` together with the named arguments.
    """
    # Keep this as the very first statement: ``locals()`` must see the call
    # arguments only, not any helper variable introduced further down.
    args = self.args.update(locals())
    init_logger(logger, verbose=args.verbose)

    self.transform.train()
    distributed = dist.is_initialized()
    if distributed:
        # Split the configured batch size across the participating processes.
        world_size = dist.get_world_size()
        args.batch_size = args.batch_size // world_size

    logger.info("Loading the data")
    train = Dataset(self.transform, args.train, **args)
    dev = Dataset(self.transform, args.dev)
    test = Dataset(self.transform, args.test)
    # Only the training set receives the two extra positional flags.
    train.build(args.batch_size, args.buckets, True, distributed)
    for heldout in (dev, test):
        heldout.build(args.batch_size, args.buckets)
    logger.info(f"\n{'train:':6} {train}\n{'dev:':6} {dev}\n{'test:':6} {test}\n")

    logger.info(f"{self.model}\n")
    if distributed:
        self.model = DDP(self.model,
                         device_ids=[args.local_rank],
                         find_unused_parameters=True)

    betas = (args.mu, args.nu)
    self.optimizer = Adam(self.model.parameters(),
                          args.lr,
                          betas,
                          args.epsilon,
                          weight_decay=args.weight_decay)
    gamma = args.decay**(1/args.decay_steps)
    self.scheduler = ExponentialLR(self.optimizer, gamma)

    elapsed = timedelta()
    best_e = 1
    best_metric = Metric()

    for epoch in range(1, args.epochs + 1):
        start = datetime.now()

        logger.info(f"Epoch {epoch} / {args.epochs}:")
        self._train(train.loader)
        loss, dev_metric = self._evaluate(dev.loader)
        logger.info(f"{'dev:':6} - loss: {loss:.4f} - {dev_metric}")
        loss, test_metric = self._evaluate(test.loader)
        logger.info(f"{'test:':6} - loss: {loss:.4f} - {test_metric}")

        t = datetime.now() - start
        improved = dev_metric > best_metric
        if not improved:
            logger.info(f"{t}s elapsed\n")
        else:
            # New best dev score: remember it and let the master process save.
            best_e, best_metric = epoch, dev_metric
            if is_master():
                self.save(args.path)
            logger.info(f"{t}s elapsed (saved)\n")
        elapsed += t

        # Early stopping: too many epochs since the last improvement.
        if args.patience <= epoch - best_e:
            break

    # Reload through ``self.load`` and score the result on the test set.
    reloaded = self.load(**args)
    loss, metric = reloaded._evaluate(test.loader)

    logger.info(f"Epoch {best_e} saved")
    logger.info(f"{'dev:':6} - {best_metric}")
    logger.info(f"{'test:':6} - {metric}")
    logger.info(f"{elapsed}s elapsed, {elapsed / epoch}s/epoch")
def train(self, train, dev, test,
          buckets=32,
          batch_size=5000,
          clip=5.0,
          epochs=5000,
          patience=100,
          **kwargs):
    """
    Train with the margin loss, keep the best checkpoint by dev metric and report its test score.

    Every call argument is folded into ``self.args`` through
    ``self.args.update(locals())``; from then on the settings are read back
    from the returned ``args`` object. This variant does not build an
    optimizer or a scheduler itself.

    Args:
        train, dev, test:
            Data sources handed to ``Dataset`` (as ``args.train`` / ``args.dev`` /
            ``args.test``). Presumably file paths -- confirm against ``Dataset``.
        buckets:
            Second argument of ``Dataset.build`` for all three sets.
        batch_size:
            First argument of ``Dataset.build``. Floor-divided by
            ``dist.get_world_size()`` when ``dist`` is initialized.
        clip:
            Not read in this method; it only reaches ``args``
            (presumably consumed by ``_train`` -- confirm).
        epochs:
            Upper bound on the number of epochs.
        patience:
            Training stops as soon as ``epoch - best_e`` reaches this value.
        kwargs:
            Extra settings, captured by ``locals()`` together with the named
            arguments. ``args.verbose``, ``args.local_rank`` and ``args.path``
            are read below but are not named parameters, so they must come
            from here or already be present on ``self.args``.
    """
    # Keep this as the very first statement: ``locals()`` must see the call
    # arguments only, not any helper variable introduced further down.
    args = self.args.update(locals())
    init_logger(logger, verbose=args.verbose)

    self.transform.train()
    distributed = dist.is_initialized()
    if distributed:
        # Split the configured batch size across the participating processes.
        world_size = dist.get_world_size()
        args.batch_size = args.batch_size // world_size

    logger.info("Loading the data")
    train = Dataset(self.transform, args.train, **args)
    dev = Dataset(self.transform, args.dev)
    test = Dataset(self.transform, args.test)
    # Only the training set receives the two extra positional flags.
    train.build(args.batch_size, args.buckets, True, distributed)
    for heldout in (dev, test):
        heldout.build(args.batch_size, args.buckets)
    logger.info(f"\n{'train:':6} {train}\n{'dev:':6} {dev}\n{'test:':6} {test}\n")

    if distributed:
        self.model = DDP(self.model,
                         device_ids=[args.local_rank],
                         find_unused_parameters=True)

    elapsed = timedelta()
    best_e = 1
    best_metric = Metric()

    for epoch in range(1, args.epochs + 1):
        start = datetime.now()

        logger.info(f"Epoch {epoch} / {args.epochs}:")
        # The margin loss is used from the very first epoch on; an earlier
        # experiment that ran epoch 1 with the default loss is disabled.
        self._train(train.loader, loss_type='margin')
        loss, dev_metric = self._evaluate(dev.loader)
        logger.info(f"{'dev:':5} loss: {loss:.4f} - {dev_metric}")
        loss, test_metric = self._evaluate(test.loader)
        logger.info(f"{'test:':5} loss: {loss:.4f} - {test_metric}")

        t = datetime.now() - start
        improved = dev_metric > best_metric
        if not improved:
            logger.info(f"{t}s elapsed\n")
        else:
            # New best dev score: remember it and let the master process save.
            best_e, best_metric = epoch, dev_metric
            if is_master():
                self.save(args.path)
            logger.info(f"{t}s elapsed (saved)\n")
        elapsed += t

        # Early stopping: too many epochs since the last improvement.
        if args.patience <= epoch - best_e:
            break

    # Reload through ``self.load`` and score the result on the test set.
    reloaded = self.load(**args)
    loss, metric = reloaded._evaluate(test.loader)

    logger.info(f"Epoch {best_e} saved")
    logger.info(f"{'dev:':5} {best_metric}")
    logger.info(f"{'test:':5} {metric}")
    logger.info(f"{elapsed}s elapsed, {elapsed / epoch}s/epoch")
def train(self, train, dev, test,
          buckets=32,
          batch_size=5000,
          update_steps=1,
          clip=5.0,
          epochs=5000,
          patience=100,
          **kwargs):
    """
    Train the model, keep the best checkpoint by dev metric and report its test score.

    Every call argument is folded into ``self.args`` through
    ``self.args.update(locals())``; from then on the settings are read back
    from the returned ``args`` object. The optimizer and scheduler depend on
    ``args.encoder``: ``'lstm'`` gets ``Adam`` + ``ExponentialLR``, anything
    else gets ``AdamW`` + a linear warm-up schedule from ``transformers``.

    Args:
        train, dev, test:
            Data sources handed to ``Dataset`` (as ``args.train`` / ``args.dev`` /
            ``args.test``). Presumably file paths -- confirm against ``Dataset``.
        buckets:
            Second argument of ``Dataset.build`` for all three sets.
        batch_size:
            First argument of ``Dataset.build``. Floor-divided by
            ``dist.get_world_size()`` when ``dist`` is initialized, and, for the
            training set only, additionally by ``update_steps``.
        update_steps:
            Divides the training batch size and the total step count given to
            the warm-up scheduler (presumably the number of gradient
            accumulation steps -- confirm against ``_train``).
        clip:
            Not read in this method; it only reaches ``args``
            (presumably consumed by ``_train`` -- confirm).
        epochs:
            Upper bound on the number of epochs; also used to size the warm-up schedule.
        patience:
            Training stops as soon as ``epoch - best_e`` reaches this value.
        kwargs:
            Extra settings, captured by ``locals()`` together with the named
            arguments. Optimizer settings such as ``args.lr``, ``args.mu``,
            ``args.nu``, ``args.eps``, ``args.weight_decay``, ``args.decay``,
            ``args.decay_steps``, ``args.lr_rate`` and ``args.warmup`` are read
            below but are not named parameters.
    """
    # Keep this as the very first statement: ``locals()`` must see the call
    # arguments only, not any helper variable introduced further down.
    args = self.args.update(locals())
    init_logger(logger, verbose=args.verbose)

    self.transform.train()
    distributed = dist.is_initialized()
    if distributed:
        # Split the configured batch size across the participating processes.
        world_size = dist.get_world_size()
        args.batch_size = args.batch_size // world_size

    logger.info("Loading the data")
    train = Dataset(self.transform, args.train, **args)
    dev = Dataset(self.transform, args.dev)
    test = Dataset(self.transform, args.test)
    # The training set uses a smaller per-step batch and two extra flags.
    train.build(args.batch_size // args.update_steps, args.buckets, True, distributed)
    for heldout in (dev, test):
        heldout.build(args.batch_size, args.buckets)
    logger.info(f"\n{'train:':6} {train}\n{'dev:':6} {dev}\n{'test:':6} {test}\n")

    if args.encoder != 'lstm':
        from transformers import AdamW, get_linear_schedule_with_warmup
        steps = len(train.loader) * epochs // args.update_steps
        # One parameter group per top-level child module; every child except
        # the one named 'encoder' has its learning rate scaled by ``args.lr_rate``.
        param_groups = []
        for n, c in self.model.named_children():
            scale = 1 if n == 'encoder' else args.lr_rate
            param_groups.append({'params': c.parameters(), 'lr': args.lr * scale})
        self.optimizer = AdamW(param_groups, args.lr)
        warmup_steps = int(steps * args.warmup)
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer,
                                                         warmup_steps,
                                                         steps)
    else:
        betas = (args.mu, args.nu)
        self.optimizer = Adam(self.model.parameters(),
                              args.lr,
                              betas,
                              args.eps,
                              args.weight_decay)
        gamma = args.decay**(1 / args.decay_steps)
        self.scheduler = ExponentialLR(self.optimizer, gamma)

    # The model is wrapped only after the optimizer has been built from it.
    if distributed:
        self.model = DDP(self.model,
                         device_ids=[args.local_rank],
                         find_unused_parameters=True)

    elapsed = timedelta()
    best_e = 1
    best_metric = Metric()

    for epoch in range(1, args.epochs + 1):
        start = datetime.now()

        logger.info(f"Epoch {epoch} / {args.epochs}:")
        self._train(train.loader)
        loss, dev_metric = self._evaluate(dev.loader)
        logger.info(f"{'dev:':5} loss: {loss:.4f} - {dev_metric}")
        loss, test_metric = self._evaluate(test.loader)
        logger.info(f"{'test:':5} loss: {loss:.4f} - {test_metric}")

        t = datetime.now() - start
        improved = dev_metric > best_metric
        if not improved:
            logger.info(f"{t}s elapsed\n")
        else:
            # New best dev score: remember it and let the master process save.
            best_e, best_metric = epoch, dev_metric
            if is_master():
                self.save(args.path)
            logger.info(f"{t}s elapsed (saved)\n")
        elapsed += t

        # Early stopping: too many epochs since the last improvement.
        if args.patience <= epoch - best_e:
            break

    # Reload through ``self.load`` and score the result on the test set.
    reloaded = self.load(**args)
    loss, metric = reloaded._evaluate(test.loader)

    logger.info(f"Epoch {best_e} saved")
    logger.info(f"{'dev:':5} {best_metric}")
    logger.info(f"{'test:':5} {metric}")
    logger.info(f"{elapsed}s elapsed, {elapsed / epoch}s/epoch")
def train(self, train, dev, test,
          buckets=32,
          batch_size=5000,
          update_steps=1,
          clip=5.0,
          epochs=5000,
          patience=100,
          **kwargs):
    """
    Train (or resume training) the model and report the test score of the best checkpoint.

    Every call argument is folded into ``self.args`` through
    ``self.args.update(locals())``. The optimizer and scheduler depend on
    ``args.encoder``: ``'lstm'`` gets ``Adam`` + ``ExponentialLR``, anything
    else gets ``AdamW`` + a linear warm-up schedule from ``transformers``.

    Progress is tracked on the instance (``self.epoch``, ``self.best_e``,
    ``self.patience``, ``self.best_metric``, ``self.elapsed``). When
    ``self.args.checkpoint`` is truthy, the optimizer, scheduler and RNG states
    are restored from ``self.checkpoint_state_dict`` and every remaining entry
    of that dict is set as an attribute on ``self`` (presumably the progress
    fields listed above -- confirm against ``save_checkpoint``).

    Args:
        train, dev, test:
            Data sources handed to ``Dataset`` (as ``args.train`` / ``args.dev`` /
            ``args.test``). Presumably file paths -- confirm against ``Dataset``.
        buckets:
            Second argument of ``Dataset.build`` for all three sets.
        batch_size:
            Floor-divided by ``update_steps`` and, when ``dist`` is initialized,
            by ``dist.get_world_size()``; the result is used for all three sets.
        update_steps:
            Divides the batch size and the total step count given to the
            warm-up scheduler (presumably the number of gradient accumulation
            steps -- confirm against ``_train``).
        clip:
            Not read in this method; it only reaches ``args``
            (presumably consumed by ``_train`` -- confirm).
        epochs:
            Upper bound on the number of epochs; also used to size the warm-up schedule.
        patience:
            Number of consecutive epochs without dev improvement after which training stops.
        kwargs:
            Extra settings, captured by ``locals()`` together with the named
            arguments. Optimizer settings such as ``args.lr``, ``args.mu``,
            ``args.nu``, ``args.eps``, ``args.weight_decay``, ``args.decay``,
            ``args.decay_steps``, ``args.lr_rate`` and ``args.warmup`` are read
            below but are not named parameters.
    """
    # Keep this as the very first statement: ``locals()`` must see the call
    # arguments only, not any helper variable introduced further down.
    args = self.args.update(locals())
    init_logger(logger, verbose=args.verbose)

    self.transform.train()
    # NOTE: the local ``batch_size`` is rescaled here, ``args.batch_size``
    # keeps the value that was passed in.
    batch_size = batch_size // update_steps
    if dist.is_initialized():
        batch_size = batch_size // dist.get_world_size()
    logger.info("Loading the data")
    train = Dataset(self.transform, args.train, **args).build(batch_size, buckets, True, dist.is_initialized())
    dev = Dataset(self.transform, args.dev).build(batch_size, buckets)
    test = Dataset(self.transform, args.test).build(batch_size, buckets)
    logger.info(f"\n{'train:':6} {train}\n{'dev:':6} {dev}\n{'test:':6} {test}\n")

    if args.encoder == 'lstm':
        self.optimizer = Adam(self.model.parameters(), args.lr, (args.mu, args.nu), args.eps, args.weight_decay)
        self.scheduler = ExponentialLR(self.optimizer, args.decay**(1 / args.decay_steps))
    else:
        from transformers import AdamW, get_linear_schedule_with_warmup
        steps = len(train.loader) * epochs // args.update_steps
        # One group per parameter; everything whose name does not start with
        # 'encoder' has its learning rate scaled by ``args.lr_rate``.
        self.optimizer = AdamW(
            [{'params': p, 'lr': args.lr * (1 if n.startswith('encoder') else args.lr_rate)}
             for n, p in self.model.named_parameters()],
            args.lr)
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer, int(steps * args.warmup), steps)

    # The model is wrapped only after the optimizer has been built from it.
    if dist.is_initialized():
        self.model = DDP(self.model, device_ids=[args.local_rank], find_unused_parameters=True)

    self.epoch, self.best_e, self.patience, self.best_metric, self.elapsed = 1, 1, patience, Metric(), timedelta()
    if self.args.checkpoint:
        self.optimizer.load_state_dict(self.checkpoint_state_dict.pop('optimizer_state_dict'))
        self.scheduler.load_state_dict(self.checkpoint_state_dict.pop('scheduler_state_dict'))
        set_rng_state(self.checkpoint_state_dict.pop('rng_state'))
        # Whatever is left in the checkpoint dict overrides the fresh progress fields.
        for k, v in self.checkpoint_state_dict.items():
            setattr(self, k, v)
        train.loader.batch_sampler.epoch = self.epoch

    # Bind ``epoch`` up front: if the loop below has nothing to do (resuming a
    # run whose stored epoch is already past ``args.epochs``, or ``epochs < 1``),
    # the loop variable would otherwise be undefined in the final summary.
    epoch = self.epoch - 1
    for epoch in range(self.epoch, args.epochs + 1):
        start = datetime.now()

        logger.info(f"Epoch {epoch} / {args.epochs}:")
        self._train(train.loader)
        loss, dev_metric = self._evaluate(dev.loader)
        logger.info(f"{'dev:':5} loss: {loss:.4f} - {dev_metric}")
        loss, test_metric = self._evaluate(test.loader)
        logger.info(f"{'test:':5} loss: {loss:.4f} - {test_metric}")

        t = datetime.now() - start
        # Progress is updated before saving so that the checkpoint written
        # below already carries the advanced epoch counter.
        self.epoch += 1
        self.patience -= 1
        self.elapsed += t

        if dev_metric > self.best_metric:
            # Improvement: reset the patience budget to its full value.
            self.best_e, self.patience, self.best_metric = epoch, patience, dev_metric
            if is_master():
                self.save_checkpoint(args.path)
            logger.info(f"{t}s elapsed (saved)\n")
        else:
            logger.info(f"{t}s elapsed\n")
        if self.patience < 1:
            break

    parser = self.load(**args)
    loss, metric = parser._evaluate(test.loader)
    # Only the master process writes the final model, mirroring the in-loop
    # checkpoint save; otherwise every rank would write the same path at once.
    if is_master():
        parser.save(args.path)

    logger.info(f"Epoch {self.best_e} saved")
    logger.info(f"{'dev:':5} {self.best_metric}")
    logger.info(f"{'test:':5} {metric}")
    # ``max(..., 1)`` avoids dividing by zero when no epoch has ever been run.
    logger.info(f"{self.elapsed}s elapsed, {self.elapsed / max(epoch, 1)}s/epoch")