def main(args):
    """Train a SQuAD question-answering model selected by ``args.model``.

    Sets up logging and TensorBoard, seeds every RNG, builds the model and
    its optimizer, then trains. Every ``args.eval_steps`` examples it
    evaluates on the dev set with the EMA weights swapped in, saves a
    checkpoint and logs the dev metrics.

    Args:
        args: Parsed command-line namespace. Fields read here include
            ``save_dir``, ``name``, ``seed``, ``model``, ``lr``, ``l2_wd``,
            ``batch_size``, ``num_epochs``, ``eval_steps``, ``load_path`` and
            the model-specific hyper-parameters. ``save_dir``, ``gpu_ids``
            and ``batch_size`` are overwritten in place.

    Raises:
        ValueError: If ``args.model`` is not one of ``'baseline'``,
            ``'bidaf'``, ``'qanet'`` or ``'qanet_out'``.
    """
    # Set up logging and devices
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=True)
    log = util.get_logger(args.save_dir, args.name)
    tbx = SummaryWriter(args.save_dir)
    device, args.gpu_ids = util.get_available_devices()
    log.info(f'Args: {dumps(vars(args), indent=4, sort_keys=True)}')
    # The model is wrapped in nn.DataParallel below, which splits each batch
    # across GPUs, so grow the total batch with the number of GPUs.
    args.batch_size *= max(1, len(args.gpu_ids))

    # Set random seed
    log.info(f'Using random seed {args.seed}...')
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Get embeddings
    log.info('Loading embeddings...')
    word_vectors = util.torch_from_json(args.word_emb_file)
    char_vectors = util.torch_from_json(args.char_emb_file)

    # Get model
    log.info('Building model...')
    if args.model == 'baseline':
        model = Baseline(word_vectors=word_vectors,
                         hidden_size=args.hidden_size,
                         drop_prob=args.drop_prob)
        optimizer = optim.Adadelta(model.parameters(), args.lr,
                                   weight_decay=args.l2_wd)
    elif args.model == 'bidaf':
        model = BiDAF(word_vectors=word_vectors,
                      char_vectors=char_vectors,
                      char_emb_dim=args.char_emb_dim,
                      hidden_size=args.hidden_size,
                      drop_prob=args.drop_prob)
        optimizer = optim.Adadelta(model.parameters(), args.lr,
                                   weight_decay=args.l2_wd)
    elif args.model in ('qanet', 'qanet_out'):
        # Both names currently build an identical QANet configuration.
        model = QANet(word_vectors=word_vectors,
                      char_vectors=char_vectors,
                      char_emb_dim=args.char_emb_dim,
                      hidden_size=args.hidden_size,
                      n_conv_emb_enc=args.n_conv_emb,
                      n_conv_mod_enc=args.n_conv_mod,
                      drop_prob_word=0.1,
                      drop_prob_char=0.05,
                      kernel_size_emb_enc_block=7,
                      kernel_size_mod_enc_block=7,
                      n_heads=args.n_heads)
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               betas=(args.beta_1, args.beta_2),
                               eps=args.epsilon,
                               weight_decay=args.l2_wd)
    else:
        # Fail here with a clear message instead of a NameError on `model`.
        raise ValueError(f'Unknown model: {args.model!r}')

    model = nn.DataParallel(model, args.gpu_ids)
    if args.load_path:
        log.info(f'Loading checkpoint from {args.load_path}...')
        model, step = util.load_model(model, args.load_path, args.gpu_ids)
    else:
        step = 0
    model = model.to(device)
    model.train()
    ema = util.EMA(model, args.ema_decay)

    # Get saver
    saver = util.CheckpointSaver(args.save_dir,
                                 max_checkpoints=args.max_checkpoints,
                                 metric_name=args.metric_name,
                                 maximize_metric=args.maximize_metric,
                                 log=log)

    # Get scheduler (the optimizer was created together with the model)
    scheduler = sched.LambdaLR(optimizer, lambda s: 1.)  # Constant LR

    # Get data loader
    log.info('Building dataset...')
    train_dataset = SQuAD(args.train_record_file, args.use_squad_v2)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   collate_fn=collate_fn)
    dev_dataset = SQuAD(args.dev_record_file, args.use_squad_v2)
    dev_loader = data.DataLoader(dev_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=collate_fn)

    # Train
    log.info('Training...')
    # `step` counts training examples seen, not optimizer updates.
    steps_till_eval = args.eval_steps
    epoch = step // len(train_dataset)
    # A negative num_epochs never terminates (as with the original `!=`
    # test); `<` also stops a resumed run that is already past num_epochs
    # instead of looping forever.
    while args.num_epochs < 0 or epoch < args.num_epochs:
        epoch += 1
        log.info(f'Starting epoch {epoch}...')
        with torch.enable_grad(), \
                tqdm(total=len(train_loader.dataset)) as progress_bar:
            for cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids in train_loader:
                # Setup for forward
                cw_idxs = cw_idxs.to(device)
                qw_idxs = qw_idxs.to(device)
                cc_idxs = cc_idxs.to(device)
                qc_idxs = qc_idxs.to(device)
                batch_size = cw_idxs.size(0)
                optimizer.zero_grad()

                # Forward
                log_p1, log_p2 = model(cw_idxs, cc_idxs, qw_idxs, qc_idxs)
                y1, y2 = y1.to(device), y2.to(device)
                loss = F.nll_loss(log_p1, y1) + F.nll_loss(log_p2, y2)
                loss_val = loss.item()

                # Backward
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step(step // batch_size)
                ema(model, step // batch_size)

                # Log info
                step += batch_size
                progress_bar.update(batch_size)
                progress_bar.set_postfix(epoch=epoch, NLL=loss_val)
                tbx.add_scalar('train/NLL', loss_val, step)
                tbx.add_scalar('train/LR',
                               optimizer.param_groups[0]['lr'],
                               step)

                steps_till_eval -= batch_size
                if steps_till_eval <= 0:
                    steps_till_eval = args.eval_steps

                    # Evaluate and save checkpoint. ema.assign / ema.resume
                    # presumably swap the EMA shadow weights in and back out.
                    log.info(f'Evaluating at step {step}...')
                    ema.assign(model)
                    results, pred_dict = evaluate(model, dev_loader, device,
                                                  args.dev_eval_file,
                                                  args.max_ans_len,
                                                  args.use_squad_v2)
                    saver.save(step, model, results[args.metric_name], device)
                    ema.resume(model)

                    # Log to console
                    results_str = ', '.join(f'{k}: {v:05.2f}'
                                            for k, v in results.items())
                    log.info(f'Dev {results_str}')

                    # Log to TensorBoard
                    log.info('Visualizing in TensorBoard...')
                    for k, v in results.items():
                        tbx.add_scalar(f'dev/{k}', v, step)
                    util.visualize(tbx,
                                   pred_dict=pred_dict,
                                   eval_path=args.dev_eval_file,
                                   step=step,
                                   split='dev',
                                   num_visuals=args.num_visuals)

    # Flush any pending TensorBoard events.
    tbx.close()
def main(_):
    """Train the model selected by ``FLAGS.model`` on cluttered MNIST.

    Runs up to three phases, each lasting ``FLAGS.epochs`` epochs. The
    cluttered training set is rebuilt at the start of every epoch.

    1. Pre-training (only when ``pre_training`` is true; forced off for
       ``'Base'`` and ``'RawG'``, forced on for ``'RawB'``).
    2. Standard training via ``net.train``.
    3. Feedback training via ``net.feedback_train`` for the
       ``'RL'``, ``'Gumbel'``, ``'Cat'``, ``'RawB'`` and ``'RawG'`` models.

    Args:
        _: Unused; presumably the argv passed by the app runner.

    Raises:
        ValueError: If ``FLAGS.model`` is not a recognised model name.
    """
    # Load MNIST data
    mnist = load_mnist()
    pre_training = FLAGS.pre_train

    # Define the deep learning model
    if FLAGS.model == 'Base':
        pre_training = False
        kernlen = int(FLAGS.frame_size / 2)
        net = Baseline(directory=FLAGS.dir,
                       optimizer=FLAGS.optimizer,
                       learning_rate=FLAGS.learning_rate,
                       layer_sizes=FLAGS.arch,
                       num_features=FLAGS.num_features,
                       num_filters=FLAGS.num_filters,
                       frame_size=FLAGS.frame_size)
    elif FLAGS.model == 'Cat':
        kernlen = int(FLAGS.frame_size / 2)
        net = Cat_Net(layer_sizes=FLAGS.arch,
                      optimizer=FLAGS.optimizer,
                      num_filters=FLAGS.num_filters,
                      num_features=FLAGS.num_features,
                      num_samples=FLAGS.num_samples,
                      frame_size=FLAGS.frame_size,
                      num_cat=FLAGS.num_cat,
                      learning_rate=FLAGS.learning_rate,
                      feedback_distance=FLAGS.feedback_distance,
                      directory=FLAGS.dir)
    elif FLAGS.model == 'Gumbel':
        kernlen = int(FLAGS.frame_size / 2)
        net = Gumbel_Net(layer_sizes=FLAGS.arch,
                         optimizer=FLAGS.optimizer,
                         num_filters=FLAGS.num_filters,
                         num_features=FLAGS.num_features,
                         frame_size=FLAGS.frame_size,
                         num_cat=FLAGS.num_cat,
                         learning_rate=FLAGS.learning_rate,
                         feedback_distance=FLAGS.feedback_distance,
                         directory=FLAGS.dir,
                         second_conv=FLAGS.second_conv,
                         initial_tau=FLAGS.initial_tau,
                         tau_decay=FLAGS.tau_decay,
                         reg=FLAGS.reg)
    elif FLAGS.model == 'RawG':
        pre_training = False
        kernlen = 60
        net = Raw_Gumbel_Net(layer_sizes=FLAGS.arch,
                             optimizer=FLAGS.optimizer,
                             num_filters=FLAGS.num_filters,
                             num_features=FLAGS.frame_size**2,
                             frame_size=FLAGS.frame_size,
                             num_cat=FLAGS.num_cat,
                             learning_rate=FLAGS.learning_rate,
                             feedback_distance=FLAGS.feedback_distance,
                             directory=FLAGS.dir,
                             second_conv=FLAGS.second_conv,
                             initial_tau=FLAGS.initial_tau,
                             meta=None)
    elif FLAGS.model == 'RL':
        kernlen = int(FLAGS.frame_size / 2)
        net = Bernoulli_Net(layer_sizes=FLAGS.arch,
                            optimizer=FLAGS.optimizer,
                            num_filters=FLAGS.num_filters,
                            num_features=FLAGS.num_features,
                            num_samples=FLAGS.num_samples,
                            frame_size=FLAGS.frame_size,
                            learning_rate=FLAGS.learning_rate,
                            feedback_distance=FLAGS.feedback_distance,
                            directory=FLAGS.dir,
                            second_conv=FLAGS.second_conv)
    elif FLAGS.model == 'RawB':
        pre_training = True
        kernlen = 60
        net = Raw_Bernoulli_Net(layer_sizes=FLAGS.arch,
                                optimizer=FLAGS.optimizer,
                                num_filters=FLAGS.num_filters,
                                num_features=FLAGS.frame_size**2,
                                num_samples=FLAGS.num_samples,
                                frame_size=FLAGS.frame_size,
                                learning_rate=FLAGS.learning_rate,
                                feedback_distance=FLAGS.feedback_distance,
                                directory=FLAGS.dir,
                                second_conv=FLAGS.second_conv)
    else:
        # Fail here with a clear message instead of a NameError on `kernlen`.
        raise ValueError(f'Unknown model: {FLAGS.model!r}')

    def _fresh_train_set():
        """Rebuild the cluttered training set from the raw MNIST images.

        Called once per epoch; convertCluttered presumably randomises the
        clutter on each call (confirm).
        """
        images, coords = convertCluttered(
            mnist.train.images,
            finalImgSize=FLAGS.frame_size,
            number_patches=FLAGS.number_patches)
        return images, coords, mnist.train.labels

    def _coords_to_kernels(coords):
        """Map each (row, col) coordinate pair to a gkern of size kernlen."""
        return np.array(
            [gkern(coord[0], coord[1], kernlen=kernlen) for coord in coords])

    X_train, train_coords, y_train = _fresh_train_set()
    train_coords = _coords_to_kernels(train_coords)
    X_test, test_coords = convertCluttered(
        mnist.test.images,
        finalImgSize=FLAGS.frame_size,
        number_patches=FLAGS.number_patches)
    y_test = mnist.test.labels
    batch_size = FLAGS.batch_size

    if pre_training:
        print("Pre-training")
        for _epoch in tqdm(range(FLAGS.epochs)):
            _x, _y = input_fn(X_test, y_test, batch_size=batch_size)
            # NOTE(review): keyword is spelled `pre_trainining`; keep it in
            # sync with net.evaluate's signature — confirm it is intended.
            net.evaluate(_x, _y, pre_trainining=True)
            X_train, train_coords, y_train = _fresh_train_set()
            net.save()
            X_train, y_train, train_coords = shuffle_in_unison(
                X_train, y_train, train_coords)
            for i in range(0, len(X_train), batch_size):
                _x, _y = input_fn(X_train[i:i + batch_size],
                                  y_train[i:i + batch_size],
                                  batch_size=batch_size)
                net.pre_train(_x, _y, dropout=0.8)

    print("Training")
    for _epoch in tqdm(range(FLAGS.epochs)):
        _x, _y = input_fn(X_test, y_test, batch_size=batch_size)
        net.evaluate(_x, _y)
        X_train, train_coords, y_train = _fresh_train_set()
        net.save()
        # Shuffle AFTER regenerating: shuffling before the rebuild would be
        # overwritten and the pass below would see unshuffled data.
        X_train, y_train, train_coords = shuffle_in_unison(
            X_train, y_train, train_coords)
        for i in range(0, len(X_train), batch_size):
            _x, _y = X_train[i:i + batch_size], y_train[i:i + batch_size]
            net.train(_x, _y, dropout=FLAGS.dropout)

    if FLAGS.model in ('RL', 'Gumbel', 'Cat', 'RawB', 'RawG'):
        print("Feedback Training")
        for _epoch in tqdm(range(FLAGS.epochs)):
            _x, _y = input_fn(X_test, y_test, batch_size=batch_size)
            net.evaluate(_x, _y)
            X_train, train_coords, y_train = _fresh_train_set()
            train_coords = _coords_to_kernels(train_coords)
            net.save()
            X_train, y_train, train_coords = shuffle_in_unison(
                X_train, y_train, train_coords)
            for _i in range(0, len(X_train), batch_size):
                # NOTE(review): unlike pre-training, the FULL arrays are passed
                # and the loop index is unused; presumably input_fn samples a
                # random batch of batch_size — confirm this is intended.
                _x, _y, _train_coords = input_fn(X_train, y_train,
                                                 train_coords,
                                                 batch_size=batch_size)
                net.feedback_train(_x, _y, _train_coords,
                                   dropout=FLAGS.dropout)
class Trainer(BaseTrainer):
    """Train a ``Baseline`` model with a softmax/triplet loss plus center loss.

    Training runs under mixed precision (``autocast`` + ``GradScaler``). The
    center-loss parameters have their own SGD optimizer. Their gradients are
    multiplied by ``1 / beta`` so that the ``beta`` weight on the center
    loss in the total loss does not change the size of their updates.

    After every epoch the trainer:

    * validates the model;
    * writes train/val loss and accuracy to TensorBoard and logs them;
    * saves ``model_last.pth``, and also ``model_best.pth`` on a new best
      validation accuracy;
    * copies ``self.logs_dir`` to ``self.logs_dir_saved``.

    ``self.device``, ``self.use_gpu``, ``self.writer``, ``self.logger``,
    ``self.checkpoint_dir``, ``self.map_location``, ``self.start_epoch``,
    ``self.epochs``, ``self.logs_dir`` and ``self.logs_dir_saved`` are used
    but not assigned here; presumably ``BaseTrainer`` provides them.
    """

    def __init__(self, config):
        """Build data, model, losses, optimizers and scheduler from ``config``.

        Args:
            config: Nested mapping with at least the ``"data"``, ``"losses"``,
                ``"optimizer"``, ``"lr_scheduler"`` and ``"resume"`` keys read
                below. A non-empty ``config["resume"]`` resumes from that
                checkpoint path.
        """
        super().__init__(config)
        self.datamanager = DataManger(config["data"])
        num_classes = self.datamanager.datasource.get_num_classes("train")

        # model
        self.model = Baseline(num_classes=num_classes)

        # summary model
        summary(
            self.model,
            input_size=(3, 256, 128),
            batch_size=config["data"]["batch_size"],
            device="cpu",
        )

        # losses
        cfg_losses = config["losses"]
        self.criterion = Softmax_Triplet_loss(
            num_class=num_classes,
            margin=cfg_losses["margin"],
            epsilon=cfg_losses["epsilon"],
            use_gpu=self.use_gpu,
        )
        self.center_loss = CenterLoss(
            num_classes=num_classes,
            feature_dim=2048,
            use_gpu=self.use_gpu,
        )

        # optimizer
        cfg_optimizer = config["optimizer"]
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=cfg_optimizer["lr"],
            weight_decay=cfg_optimizer["weight_decay"],
        )
        self.optimizer_centerloss = torch.optim.SGD(
            self.center_loss.parameters(), lr=0.5
        )

        # learning rate scheduler
        cfg_lr_scheduler = config["lr_scheduler"]
        self.lr_scheduler = WarmupMultiStepLR(
            self.optimizer,
            milestones=cfg_lr_scheduler["steps"],
            gamma=cfg_lr_scheduler["gamma"],
            warmup_factor=cfg_lr_scheduler["factor"],
            warmup_iters=cfg_lr_scheduler["iters"],
            warmup_method=cfg_lr_scheduler["method"],
        )

        # track metric
        self.train_metrics = MetricTracker("loss", "accuracy")
        self.valid_metrics = MetricTracker("loss", "accuracy")

        # best validation accuracy so far; used to pick model_best.pth
        self.best_accuracy = None

        # send model to device
        self.model.to(self.device)

        self.scaler = GradScaler()

        # resume model from last checkpoint
        if config["resume"] != "":
            self._resume_checkpoint(config["resume"])

    def train(self):
        """Run epochs ``self.start_epoch`` .. ``self.epochs`` inclusive."""
        for epoch in range(self.start_epoch, self.epochs + 1):
            train_result = self._train_epoch(epoch)

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            valid_result = self._valid_epoch(epoch)

            # add scalars to tensorboard
            self.writer.add_scalars(
                "Loss",
                {
                    "Train": self.train_metrics.avg("loss"),
                    "Val": self.valid_metrics.avg("loss"),
                },
                global_step=epoch,
            )
            self.writer.add_scalars(
                "Accuracy",
                {
                    "Train": self.train_metrics.avg("accuracy"),
                    "Val": self.valid_metrics.avg("accuracy"),
                },
                global_step=epoch,
            )

            # Log both phases; the prefixes keep the two sets of metrics,
            # which share key names, from overwriting each other.
            log = {"epoch": epoch}
            log.update({f"train_{k}": v for k, v in dict(train_result).items()})
            log.update({f"val_{k}": v for k, v in dict(valid_result).items()})
            for key, value in log.items():
                self.logger.info("    {:15s}: {}".format(str(key), value))

            # save model
            val_accuracy = self.valid_metrics.avg("accuracy")
            if self.best_accuracy is None or self.best_accuracy < val_accuracy:
                self.best_accuracy = val_accuracy
                self._save_checkpoint(epoch, save_best=True)
            else:
                self._save_checkpoint(epoch, save_best=False)

            # save logs
            self._save_logs(epoch)

    def _train_epoch(self, epoch):
        """Training step: one mixed-precision pass over the "train" split.

        Returns:
            ``self.train_metrics.result()`` after the epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        beta = self.config["losses"]["beta"]
        dataloader = self.datamanager.get_dataloader("train")

        with tqdm(total=len(dataloader)) as epoch_pbar:
            epoch_pbar.set_description(f"Epoch {epoch}")
            for batch_idx, (data, labels, _) in enumerate(dataloader):
                # push data to device
                data, labels = data.to(self.device), labels.to(self.device)

                # zero gradient
                self.optimizer.zero_grad()
                self.optimizer_centerloss.zero_grad()

                with autocast():
                    # forward batch
                    score, feat = self.model(data)

                    # calculate loss and accuracy
                    loss = (
                        self.criterion(score, feat, labels)
                        + self.center_loss(feat, labels) * beta
                    )
                _, preds = torch.max(score.data, dim=1)

                # backward (gradients are multiplied by the loss scale)
                self.scaler.scale(loss).backward()

                # The center-loss optimizer does not otherwise go through the
                # scaler, so unscale its gradients explicitly before touching
                # them; otherwise SGD would step on loss-scaled gradients.
                self.scaler.unscale_(self.optimizer_centerloss)
                # Remove the beta weighting so the centres' update size does
                # not depend on beta.
                for param in self.center_loss.parameters():
                    param.grad.data *= 1.0 / beta

                # optimize; scaler.step skips a step if inf/NaN grads were found
                self.scaler.step(self.optimizer)
                self.scaler.step(self.optimizer_centerloss)
                self.scaler.update()

                # update loss and accuracy in MetricTracker
                self.train_metrics.update("loss", loss.item())
                self.train_metrics.update(
                    "accuracy",
                    torch.sum(preds == labels.data).double().item() / data.size(0),
                )

                # update process bar
                epoch_pbar.set_postfix(
                    {
                        "train_loss": self.train_metrics.avg("loss"),
                        "train_acc": self.train_metrics.avg("accuracy"),
                    }
                )
                epoch_pbar.update(1)
        return self.train_metrics.result()

    def _valid_epoch(self, epoch):
        """Validation step: one no-grad pass over the "val" split.

        Returns:
            ``self.valid_metrics.result()`` after the epoch.
        """
        self.model.eval()
        self.valid_metrics.reset()
        beta = self.config["losses"]["beta"]
        dataloader = self.datamanager.get_dataloader("val")

        with torch.no_grad(), tqdm(total=len(dataloader)) as epoch_pbar:
            epoch_pbar.set_description(f"Epoch {epoch}")
            for batch_idx, (data, labels, _) in enumerate(dataloader):
                # push data to device
                data, labels = data.to(self.device), labels.to(self.device)

                with autocast():
                    # forward batch
                    score, feat = self.model(data)

                    # calculate loss and accuracy
                    loss = (
                        self.criterion(score, feat, labels)
                        + self.center_loss(feat, labels) * beta
                    )
                _, preds = torch.max(score.data, dim=1)

                # update loss and accuracy in MetricTracker
                self.valid_metrics.update("loss", loss.item())
                self.valid_metrics.update(
                    "accuracy",
                    torch.sum(preds == labels.data).double().item() / data.size(0),
                )

                # update process bar
                epoch_pbar.set_postfix(
                    {
                        "val_loss": self.valid_metrics.avg("loss"),
                        "val_acc": self.valid_metrics.avg("accuracy"),
                    }
                )
                epoch_pbar.update(1)
        return self.valid_metrics.result()

    def _save_checkpoint(self, epoch, save_best=True):
        """Save training state to model_last.pth (and model_best.pth if best)."""
        state = {
            "epoch": epoch,
            "state_dict": self.model.state_dict(),
            "center_loss": self.center_loss.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "optimizer_centerloss": self.optimizer_centerloss.state_dict(),
            "lr_scheduler": self.lr_scheduler.state_dict(),
            "best_accuracy": self.best_accuracy,
        }
        filename = os.path.join(self.checkpoint_dir, "model_last.pth")
        self.logger.info("Saving last model: model_last.pth ...")
        torch.save(state, filename)
        if save_best:
            filename = os.path.join(self.checkpoint_dir, "model_best.pth")
            self.logger.info("Saving current best: model_best.pth ...")
            torch.save(state, filename)

    def _resume_checkpoint(self, resume_path):
        """Restore training state written by ``_save_checkpoint``.

        Raises:
            FileNotFoundError: If ``resume_path`` does not exist.
        """
        if not os.path.exists(resume_path):
            raise FileNotFoundError(f"Resume path does not exist: {resume_path}")
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        # NOTE: torch.load unpickles the file; only resume from trusted paths.
        checkpoint = torch.load(resume_path, map_location=self.map_location)
        self.start_epoch = checkpoint["epoch"] + 1
        self.model.load_state_dict(checkpoint["state_dict"])
        self.center_loss.load_state_dict(checkpoint["center_loss"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.optimizer_centerloss.load_state_dict(checkpoint["optimizer_centerloss"])
        self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        self.best_accuracy = checkpoint["best_accuracy"]

        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)
        )

    def _save_logs(self, epoch):
        """Replace ``self.logs_dir_saved`` with a fresh copy of ``self.logs_dir``.

        (Original intent: save logs from Google Colab to Google Drive.)
        """
        if os.path.isdir(self.logs_dir_saved):
            shutil.rmtree(self.logs_dir_saved)
        shutil.copytree(self.logs_dir, self.logs_dir_saved)