def evaluate_epoch(self, epoch):
    LOSSES_NAME = self.args.LOSSES_NAME

    epoch_results = {}
    for loss_name in LOSSES_NAME:
        epoch_results[loss_name] = 0.
        epoch_results[f'{loss_name}_count'] = 0

    uid2ans = {}

    self.model.eval()
    with torch.no_grad():
        if self.verbose:
            loss_meter = LossMeter()
            loss_meters = [LossMeter() for _ in range(len(LOSSES_NAME))]
            pbar = tqdm(total=len(self.val_loader), ncols=250)

        for step_i, batch in enumerate(self.val_loader):

            if self.args.distributed:
                results = self.model.module.valid_step(batch)
            else:
                results = self.model.valid_step(batch)

            if 'qa' in self.args.losses:
                qa_pred = results['qa_pred']
                for uid, ans in zip(batch['uid'], qa_pred):
                    uid2ans[uid] = ans

            for k, v in results.items():
                if k in epoch_results:
                    if isinstance(v, int):
                        epoch_results[k] += v
                    elif isinstance(v, torch.Tensor):
                        epoch_results[k] += v.item()

            if self.verbose:
                desc_str = f'Valid Epoch {epoch} |'
                for i, (loss_name, loss_meter) in enumerate(zip(LOSSES_NAME, loss_meters)):
                    if loss_name in results:
                        loss_meter.update(results[f'{loss_name}'] / results[f'{loss_name}_count'])
                    if len(loss_meter) > 0:
                        loss_count = epoch_results[f'{loss_name}_count']
                        desc_str += f' {loss_name} ({loss_count}) {loss_meter.val:.3f}'

                pbar.set_description(desc_str)
                pbar.update(1)

            dist.barrier()

        if self.verbose:
            pbar.close()

        dist.barrier()

    if 'qa' not in self.args.losses:
        uid2ans = None

    return epoch_results, uid2ans
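# The VL-T5-style loops in this file only rely on a small LossMeter surface:
# update(), val, and __len__(). A minimal sketch of such a meter, assuming a
# fixed-size rolling window (the window size and field names are illustrative,
# not taken from the original code):
from collections import deque


class LossMeter:
    def __init__(self, maxlen=100):
        """Keep a rolling window of recent loss values."""
        self.vals = deque(maxlen=maxlen)

    def __len__(self):
        return len(self.vals)

    def update(self, new_val):
        self.vals.append(new_val)

    @property
    def val(self):
        # Mean over the current window; callers format this for the tqdm bar.
        return sum(self.vals) / len(self.vals)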
def __init__(self, args):
    self.args = args
    self.args.n_datasets = len(args.data)
    self.modelPath = Path('checkpoints') / args.expName

    self.logger = create_output_dir(args, self.modelPath)
    self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]

    self.losses_recon = [LossMeter(f'recon {i}')
                         for i in range(self.args.n_datasets)]
    self.loss_total = LossMeter('total')

    self.evals_recon = [LossMeter(f'recon {i}')
                        for i in range(self.args.n_datasets)]
    self.eval_total = LossMeter('eval total')

    self.start_epoch = 0

    # torch.manual_seed(args.seed)
    # torch.cuda.manual_seed(args.seed)

    # Locate the pretrained checkpoint whose id matches the requested decoder.
    checkpoint = args.checkpoint.parent.glob(args.checkpoint.name + '_*.pth')
    checkpoint = [c for c in checkpoint if extract_id(c) in args.decoder][0]
    model_args = torch.load(args.checkpoint.parent / 'args.pth')[0]
    states = torch.load(checkpoint)

    self.encoder = Encoder(model_args)
    self.encoder.load_state_dict(states['encoder_state'])
    # Freeze the encoder entirely.
    for param in self.encoder.parameters():
        param.requires_grad = False

    self.decoder = WaveNet(model_args)
    self.decoder.load_state_dict(states['decoder_state'])
    # Freeze all decoder layers except the last `decoder_update` ones.
    for param in self.decoder.layers[:-args.decoder_update].parameters():
        param.requires_grad = False

    self.encoder = torch.nn.DataParallel(self.encoder).cuda()
    self.decoder = torch.nn.DataParallel(self.decoder).cuda()

    self.model_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                            self.decoder.parameters()),
                                      lr=args.lr)
    self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
        self.model_optimizer, args.lr_decay)
    self.lr_manager.step()
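# `extract_id` above is assumed to map a checkpoint filename such as
# "lastmodel_3.pth" to the integer id checked against `args.decoder`.
# A minimal sketch under that assumption (the naming convention is
# illustrative, not confirmed by the original code):
from pathlib import Path


def extract_id(path):
    """Return the trailing integer of a checkpoint stem, e.g. 'lastmodel_3' -> 3."""
    return int(Path(path).stem.split('_')[-1])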
def __init__(self, args):
    self.args = args
    self.args.n_datasets = len(self.args.data)
    self.expPath = Path('checkpoints') / args.expName

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    self.logger = create_output_dir(args, self.expPath)
    self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]

    self.losses_recon = [LossMeter(f'recon {i}')
                         for i in range(self.args.n_datasets)]
    self.loss_total = LossMeter('total')

    self.evals_recon = [LossMeter(f'recon {i}')
                        for i in range(self.args.n_datasets)]
    self.eval_total = LossMeter('eval total')

    self.encoder = Encoder(args)
    self.decoder = WaveNet(args)

    assert args.checkpoint, 'you MUST pass a checkpoint for the encoder'

    if args.continue_training:
        checkpoint_args_path = os.path.dirname(args.checkpoint) + '/args.pth'
        checkpoint_args = torch.load(checkpoint_args_path)
        self.start_epoch = checkpoint_args[-1] + 1
    else:
        self.start_epoch = 0

    states = torch.load(args.checkpoint)
    self.encoder.load_state_dict(states['encoder_state'])
    if args.continue_training:
        self.decoder.load_state_dict(states['decoder_state'])
    self.logger.info('Loaded checkpoint parameters')

    self.encoder = torch.nn.DataParallel(self.encoder).cuda()
    self.decoder = torch.nn.DataParallel(self.decoder).cuda()

    self.model_optimizer = optim.Adam(self.decoder.parameters(), lr=args.lr)

    if args.continue_training:
        self.model_optimizer.load_state_dict(states['model_optimizer_state'])

    self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
        self.model_optimizer, args.lr_decay)
    self.lr_manager.last_epoch = self.start_epoch
    self.lr_manager.step()
class Trainer:
    def __init__(self, args):
        self.args = args
        self.args.n_datasets = len(self.args.data)
        self.expPath = Path('checkpoints') / args.expName

        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        self.logger = create_output_dir(args, self.expPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]

        self.losses_recon = [LossMeter(f'recon {i}')
                             for i in range(self.args.n_datasets)]
        self.loss_total = LossMeter('total')

        self.evals_recon = [LossMeter(f'recon {i}')
                            for i in range(self.args.n_datasets)]
        self.eval_total = LossMeter('eval total')

        self.encoder = Encoder(args)
        self.decoder = WaveNet(args)

        assert args.checkpoint, 'you MUST pass a checkpoint for the encoder'

        if args.continue_training:
            checkpoint_args_path = os.path.dirname(args.checkpoint) + '/args.pth'
            checkpoint_args = torch.load(checkpoint_args_path)
            self.start_epoch = checkpoint_args[-1] + 1
        else:
            self.start_epoch = 0

        states = torch.load(args.checkpoint)
        self.encoder.load_state_dict(states['encoder_state'])
        if args.continue_training:
            self.decoder.load_state_dict(states['decoder_state'])
        self.logger.info('Loaded checkpoint parameters')

        self.encoder = torch.nn.DataParallel(self.encoder).cuda()
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()

        self.model_optimizer = optim.Adam(self.decoder.parameters(), lr=args.lr)

        if args.continue_training:
            self.model_optimizer.load_state_dict(states['model_optimizer_state'])

        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
        self.lr_manager.last_epoch = self.start_epoch
        self.lr_manager.step()

    def eval_batch(self, x, x_aug, dset_num):
        x, x_aug = x.float(), x_aug.float()

        z = self.encoder(x)
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        total_loss = recon_loss.mean().data.item()
        self.eval_total.add(total_loss)

        return total_loss

    def train_batch(self, x, x_aug, dset_num):
        x, x_aug = x.float(), x_aug.float()

        # optimize G - reconstructs well
        z = self.encoder(x_aug)
        z = z.detach()  # stop gradients
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        loss = recon_loss.mean()
        self.model_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.decoder.parameters(), self.args.grad_clip)
        self.model_optimizer.step()

        self.loss_total.add(loss.data.item())

        return loss.data.item()

    def train_epoch(self, epoch):
        for meter in self.losses_recon:
            meter.reset()
        self.loss_total.reset()

        self.encoder.eval()
        self.decoder.train()

        n_batches = self.args.epoch_len

        with tqdm(total=n_batches, desc='Train epoch %d' % epoch) as train_enum:
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 3:
                    break

                dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].train_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.train_batch(x, x_aug, dset_num)

                train_enum.set_description(
                    f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
                train_enum.update()

    def evaluate_epoch(self, epoch):
        for meter in self.evals_recon:
            meter.reset()
        self.eval_total.reset()

        self.encoder.eval()
        self.decoder.eval()

        n_batches = int(np.ceil(self.args.epoch_len / 10))

        with tqdm(total=n_batches) as valid_enum, \
                torch.no_grad():
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 10:
                    break

                dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].valid_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.eval_batch(x, x_aug, dset_num)

                valid_enum.set_description(
                    f'Test (loss: {batch_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    @staticmethod
    def format_losses(meters):
        losses = [meter.summarize_epoch() for meter in meters]
        return ', '.join('{:.4f}'.format(x) for x in losses)

    def train_losses(self):
        meters = [*self.losses_recon]
        return self.format_losses(meters)

    def eval_losses(self):
        meters = [*self.evals_recon]
        return self.format_losses(meters)

    def train(self):
        best_eval = float('inf')

        # Begin!
        for epoch in range(self.start_epoch,
                           self.start_epoch + self.args.epochs):
            self.logger.info(
                f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}')
            self.train_epoch(epoch)
            self.evaluate_epoch(epoch)

            self.logger.info(
                f'Epoch %s Rank {self.args.rank} - Train loss: (%s), Test loss (%s)',
                epoch, self.train_losses(), self.eval_losses())
            self.lr_manager.step()
            val_loss = self.eval_total.summarize_epoch()

            if val_loss < best_eval:
                self.save_model(f'bestmodel_{self.args.rank}.pth')
                best_eval = val_loss

            if not self.args.per_epoch:
                self.save_model(f'lastmodel_{self.args.rank}.pth')
            else:
                self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth')

            torch.save([self.args, epoch], '%s/args.pth' % self.expPath)

            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        save_path = self.expPath / filename
        states = torch.load(self.args.checkpoint)

        torch.save({'encoder_state': states['encoder_state'],
                    'decoder_state': self.decoder.module.state_dict(),
                    'model_optimizer_state': self.model_optimizer.state_dict(),
                    'dataset': self.args.rank,
                    }, save_path)

        self.logger.debug(f'Saved model to {save_path}')
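# The WaveNet trainers above assume a different meter interface than the
# rolling-window meter sketched earlier: add(), reset(), and
# summarize_epoch(). A minimal sketch, assuming summarize_epoch() is simply
# the mean of everything accumulated since the last reset (field names are
# illustrative):
class LossMeter:
    def __init__(self, name=''):
        self.name = name
        self.losses = []

    def reset(self):
        self.losses = []

    def add(self, val):
        self.losses.append(val)

    def summarize_epoch(self):
        # Epoch mean; format_losses() above prints this per dataset.
        return sum(self.losses) / len(self.losses) if self.losses else 0.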
def train(self): if self.verbose: loss_meter = LossMeter() best_valid = 0. best_epoch = 0 if 't5' in self.args.backbone: if self.args.use_vision: project_name = "VLT5_VQA" else: project_name = "T5_VQA" elif 'bart' in self.args.backbone: if self.args.use_vision: project_name = "VLBart_VQA" else: project_name = "Bart_VQA" wandb.init(project=project_name) wandb.run.name = self.args.run_name wandb.config.update(self.args) wandb.watch(self.model) src_dir = Path(__file__).resolve().parent base_path = str(src_dir.parent) src_dir = str(src_dir) wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path) if self.args.distributed: dist.barrier() global_step = 0 for epoch in range(self.args.epochs): if self.start_epoch is not None: epoch += self.start_epoch self.model.train() if self.args.distributed: self.train_loader.sampler.set_epoch(epoch) if self.verbose: pbar = tqdm(total=len(self.train_loader), ncols=120) epoch_results = { 'loss': 0., } quesid2ans = {} for step_i, batch in enumerate(self.train_loader): if self.args.fp16 and _use_native_amp: with autocast(): if self.args.distributed: results = self.model.module.train_step(batch) else: results = self.model.train_step(batch) else: if self.args.distributed: results = self.model.module.train_step(batch) else: results = self.model.train_step(batch) loss = results['loss'] if self.args.fp16 and _use_native_amp: self.scaler.scale(loss).backward() elif self.args.fp16 and _use_apex: with amp.scale_loss(loss, self.optim) as scaled_loss: scaled_loss.backward() else: loss.backward() loss = loss.detach() # Update Parameters if self.args.clip_grad_norm > 0: if self.args.fp16 and _use_native_amp: self.scaler.unscale_(self.optim) torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.clip_grad_norm) elif self.args.fp16 and _use_apex: torch.nn.utils.clip_grad_norm_(amp.master_params( self.optim), self.args.clip_grad_norm) else: torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.clip_grad_norm) if self.args.fp16 and _use_native_amp: self.scaler.step(self.optim) self.scaler.update() else: self.optim.step() if self.lr_scheduler: self.lr_scheduler.step() for param in self.model.parameters(): param.grad = None global_step += 1 for k, v in results.items(): if k in epoch_results: epoch_results[k] += v.item() if self.lr_scheduler: if version.parse(torch.__version__) >= version.parse("1.4"): lr = self.lr_scheduler.get_last_lr()[0] else: lr = self.lr_scheduler.get_lr()[0] else: try: lr = self.optim.get_lr()[0] except AttributeError: lr = self.args.lr if self.verbose: loss_meter.update(loss.item()) desc_str = f'Epoch {epoch} | LR {lr:.6f}' desc_str += f' | Loss {loss_meter.val:4f}' pbar.set_description(desc_str) pbar.update(1) if self.args.distributed: dist.barrier() if self.verbose: pbar.close() # Validation score_dict = self.evaluate(self.val_loader) if self.verbose: valid_score = score_dict['topk_score'] * 100. 
valid_score_raw = score_dict['overall'] if valid_score_raw > best_valid or epoch == 0: best_valid = valid_score_raw best_epoch = epoch self.save("BEST") log_str = '' log_str += "\nEpoch %d: Valid Raw %0.2f Topk %0.2f" % (epoch, valid_score_raw, valid_score) log_str += "\nEpoch %d: Best Raw %0.2f\n" % (best_epoch, best_valid) wandb_log_dict = {} wandb_log_dict['Train/Loss'] = epoch_results['loss'] / len(self.train_loader) wandb_log_dict['Valid/score'] = valid_score wandb_log_dict['Valid/raw_score'] = score_dict['overall'] for qtype, score in score_dict['perQuestionType'].items(): wandb_log_dict[f'Valid_Qtypes/{qtype}'] = score for atype, score in score_dict['perAnswerType'].items(): if atype == 'yes/no': atype = 'yes_no' wandb_log_dict[f'Valid_Atypes/{atype}'] = score wandb.log(wandb_log_dict, step=epoch) print(log_str) if self.args.distributed: dist.barrier() if self.verbose: self.save("LAST") # Test Set best_path = os.path.join(self.args.output, 'BEST') self.load(best_path) quesid2ans = self.predict(self.test_loader) if self.verbose: evaluator = self.test_loader.evaluator score_dict = evaluator.evaluate(quesid2ans) evaluator.dump_result(quesid2ans) acc_dict_all = evaluator.evaluate_raw(quesid2ans) acc_dict_answerable = evaluator.evaluate_raw(quesid2ans, is_topk_optimal=True) acc_dict_unanswerable = evaluator.evaluate_raw(quesid2ans, is_topk_optimal=False) wandb_log_dict = {} wandb_log_dict['Test/overall'] = acc_dict_all['overall'] wandb_log_dict['Test/topk_optimal'] = acc_dict_answerable['overall'] wandb_log_dict['Test/topk_not_optimal'] = acc_dict_unanswerable['overall'] for qtype, score in acc_dict_all['perQuestionType'].items(): wandb_log_dict[f'Test_Qtypes/{qtype}'] = score for atype, score in acc_dict_all['perAnswerType'].items(): if atype == 'yes/no': atype = 'yes_no' wandb_log_dict[f'Test_Atypes/{atype}'] = score print(wandb_log_dict) wandb.log(wandb_log_dict) if self.args.submit: dump_path = os.path.join(self.args.output, 'submit.json') self.predict(self.submit_test_loader, dump_path) wandb.save(dump_path, base_path=self.args.output) wandb.log({'finished': True}) if self.args.distributed: dist.barrier() exit()
def __init__(self, args):
    self.args = args
    self.args.n_datasets = len(self.args.data)
    self.expPath = Path('checkpoints') / args.expName

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    self.logger = create_output_dir(args, self.expPath)
    self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]
    assert not args.distributed or len(self.data) == int(
        os.environ['WORLD_SIZE']), "Number of datasets must match number of nodes"

    self.losses_recon = [LossMeter(f'recon {i}')
                         for i in range(self.args.n_datasets)]
    self.loss_d_right = LossMeter('d')
    self.loss_total = LossMeter('total')

    self.evals_recon = [LossMeter(f'recon {i}')
                        for i in range(self.args.n_datasets)]
    self.eval_d_right = LossMeter('eval d')
    self.eval_total = LossMeter('eval total')

    self.encoder = Encoder(args)
    self.decoder = WaveNet(args)
    self.discriminator = ZDiscriminator(args)

    if args.checkpoint:
        checkpoint_args_path = os.path.dirname(args.checkpoint) + '/args.pth'
        checkpoint_args = torch.load(checkpoint_args_path)

        self.start_epoch = checkpoint_args[-1] + 1
        states = torch.load(args.checkpoint)

        self.encoder.load_state_dict(states['encoder_state'])
        self.decoder.load_state_dict(states['decoder_state'])
        self.discriminator.load_state_dict(states['discriminator_state'])

        self.logger.info('Loaded checkpoint parameters')
    else:
        self.start_epoch = 0

    if args.distributed:
        self.encoder.cuda()
        self.encoder = torch.nn.parallel.DistributedDataParallel(self.encoder)
        self.discriminator.cuda()
        self.discriminator = torch.nn.parallel.DistributedDataParallel(self.discriminator)
        self.logger.info('Created DistributedDataParallel')
    else:
        self.encoder = torch.nn.DataParallel(self.encoder).cuda()
        self.discriminator = torch.nn.DataParallel(self.discriminator).cuda()
    self.decoder = torch.nn.DataParallel(self.decoder).cuda()

    self.model_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                            self.decoder.parameters()),
                                      lr=args.lr)
    self.d_optimizer = optim.Adam(self.discriminator.parameters(), lr=args.lr)

    if args.checkpoint and args.load_optimizer:
        self.model_optimizer.load_state_dict(states['model_optimizer_state'])
        self.d_optimizer.load_state_dict(states['d_optimizer_state'])

    self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
        self.model_optimizer, args.lr_decay)
    self.lr_manager.last_epoch = self.start_epoch
    self.lr_manager.step()
class Trainer: def __init__(self, args): self.args = args self.args.n_datasets = len(self.args.data) self.expPath = Path('checkpoints') / args.expName torch.manual_seed(args.seed) torch.cuda.manual_seed(args.seed) self.logger = create_output_dir(args, self.expPath) self.data = [DatasetSet(d, args.seq_len, args) for d in args.data] assert not args.distributed or len(self.data) == int( os.environ['WORLD_SIZE'] ), "Number of datasets must match number of nodes" self.losses_recon = [ LossMeter(f'recon {i}') for i in range(self.args.n_datasets) ] self.loss_d_right = LossMeter('d') self.loss_total = LossMeter('total') self.evals_recon = [ LossMeter(f'recon {i}') for i in range(self.args.n_datasets) ] self.eval_d_right = LossMeter('eval d') self.eval_total = LossMeter('eval total') self.encoder = Encoder(args) self.decoder = WaveNet(args) self.discriminator = ZDiscriminator(args) if args.checkpoint: checkpoint_args_path = os.path.dirname( args.checkpoint) + '/args.pth' checkpoint_args = torch.load(checkpoint_args_path) self.start_epoch = checkpoint_args[-1] + 1 states = torch.load(args.checkpoint) self.encoder.load_state_dict(states['encoder_state']) self.decoder.load_state_dict(states['decoder_state']) self.discriminator.load_state_dict(states['discriminator_state']) self.logger.info('Loaded checkpoint parameters') else: self.start_epoch = 0 if args.distributed: self.encoder.cuda() self.encoder = torch.nn.parallel.DistributedDataParallel( self.encoder) self.discriminator.cuda() self.discriminator = torch.nn.parallel.DistributedDataParallel( self.discriminator) self.logger.info('Created DistributedDataParallel') else: self.encoder = torch.nn.DataParallel(self.encoder).cuda() self.discriminator = torch.nn.DataParallel( self.discriminator).cuda() self.decoder = torch.nn.DataParallel(self.decoder).cuda() self.model_optimizer = optim.Adam(chain(self.encoder.parameters(), self.decoder.parameters()), lr=args.lr) self.d_optimizer = optim.Adam(self.discriminator.parameters(), lr=args.lr) if args.checkpoint and args.load_optimizer: self.model_optimizer.load_state_dict( states['model_optimizer_state']) self.d_optimizer.load_state_dict(states['d_optimizer_state']) self.lr_manager = torch.optim.lr_scheduler.ExponentialLR( self.model_optimizer, args.lr_decay) self.lr_manager.last_epoch = self.start_epoch self.lr_manager.step() def eval_batch(self, x, x_aug, dset_num): x, x_aug = x.float(), x_aug.float() z = self.encoder(x) y = self.decoder(x, z) z_logits = self.discriminator(z) z_classification = torch.max(z_logits, dim=1)[1] z_accuracy = (z_classification == dset_num).float().mean() self.eval_d_right.add(z_accuracy.data.item()) # discriminator_right = F.cross_entropy(z_logits, dset_num).mean() discriminator_right = F.cross_entropy( z_logits, torch.tensor([dset_num] * x.size(0)).long().cuda()).mean() recon_loss = cross_entropy_loss(y, x) self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean()) total_loss = discriminator_right.data.item() * self.args.d_lambda + \ recon_loss.mean().data.item() self.eval_total.add(total_loss) return total_loss def train_batch(self, x, x_aug, dset_num): x, x_aug = x.float(), x_aug.float() # Optimize D - discriminator right z = self.encoder(x) z_logits = self.discriminator(z) discriminator_right = F.cross_entropy( z_logits, torch.tensor([dset_num] * x.size(0)).long().cuda()).mean() loss = discriminator_right * self.args.d_lambda self.d_optimizer.zero_grad() loss.backward() if self.args.grad_clip is not None: clip_grad_value_(self.discriminator.parameters(), 
self.args.grad_clip) self.d_optimizer.step() # optimize G - reconstructs well, discriminator wrong z = self.encoder(x_aug) y = self.decoder(x, z) z_logits = self.discriminator(z) discriminator_wrong = -F.cross_entropy( z_logits, torch.tensor([dset_num] * x.size(0)).long().cuda()).mean() if not (-100 < discriminator_right.data.item() < 100): self.logger.debug(f'z_logits: {z_logits.detach().cpu().numpy()}') self.logger.debug(f'dset_num: {dset_num}') recon_loss = cross_entropy_loss(y, x) self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean()) loss = (recon_loss.mean() + self.args.d_lambda * discriminator_wrong) self.model_optimizer.zero_grad() loss.backward() if self.args.grad_clip is not None: clip_grad_value_(self.encoder.parameters(), self.args.grad_clip) clip_grad_value_(self.decoder.parameters(), self.args.grad_clip) self.model_optimizer.step() self.loss_total.add(loss.data.item()) return loss.data.item() def train_epoch(self, epoch): for meter in self.losses_recon: meter.reset() self.loss_d_right.reset() self.loss_total.reset() self.encoder.train() self.decoder.train() self.discriminator.train() n_batches = self.args.epoch_len with tqdm(total=n_batches, desc='Train epoch %d' % epoch) as train_enum: for batch_num in range(n_batches): if self.args.short and batch_num == 3: break if self.args.distributed: assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset" # dset_num = (batch_num + self.args.rank) % self.args.n_datasets dset_num = self.args.rank else: dset_num = batch_num % self.args.n_datasets x, x_aug = next(self.data[dset_num].train_iter) x = wrap(x) x_aug = wrap(x_aug) batch_loss = self.train_batch(x, x_aug, dset_num) train_enum.set_description( f'Train (loss: {batch_loss:.2f}) epoch {epoch}') train_enum.update() def evaluate_epoch(self, epoch): for meter in self.evals_recon: meter.reset() self.eval_d_right.reset() self.eval_total.reset() self.encoder.eval() self.decoder.eval() self.discriminator.eval() n_batches = int(np.ceil(self.args.epoch_len / 10)) with tqdm(total=n_batches) as valid_enum, \ torch.no_grad(): for batch_num in range(n_batches): if self.args.short and batch_num == 10: break if self.args.distributed: assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset" dset_num = self.args.rank else: dset_num = batch_num % self.args.n_datasets x, x_aug = next(self.data[dset_num].valid_iter) x = wrap(x) x_aug = wrap(x_aug) batch_loss = self.eval_batch(x, x_aug, dset_num) valid_enum.set_description( f'Test (loss: {batch_loss:.2f}) epoch {epoch}') valid_enum.update() @staticmethod def format_losses(meters): losses = [meter.summarize_epoch() for meter in meters] return ', '.join('{:.4f}'.format(x) for x in losses) def train_losses(self): meters = [*self.losses_recon, self.loss_d_right] return self.format_losses(meters) def eval_losses(self): meters = [*self.evals_recon, self.eval_d_right] return self.format_losses(meters) def train(self): best_eval = float('inf') # Begin! 
for epoch in range(self.start_epoch, self.start_epoch + self.args.epochs): self.logger.info( f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}' ) self.train_epoch(epoch) self.evaluate_epoch(epoch) self.logger.info( f'Epoch %s Rank {self.args.rank} - Train loss: (%s), Test loss (%s)', epoch, self.train_losses(), self.eval_losses()) self.lr_manager.step() val_loss = self.eval_total.summarize_epoch() if val_loss < best_eval: self.save_model(f'bestmodel_{self.args.rank}.pth') best_eval = val_loss if not self.args.per_epoch: self.save_model(f'lastmodel_{self.args.rank}.pth') else: self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth') if self.args.is_master: torch.save([self.args, epoch], '%s/args.pth' % self.expPath) self.logger.debug('Ended epoch') def save_model(self, filename): save_path = self.expPath / filename torch.save( { 'encoder_state': self.encoder.module.state_dict(), 'decoder_state': self.decoder.module.state_dict(), 'discriminator_state': self.discriminator.module.state_dict(), 'model_optimizer_state': self.model_optimizer.state_dict(), 'dataset': self.args.rank, 'd_optimizer_state': self.d_optimizer.state_dict() }, save_path) self.logger.debug(f'Saved model to {save_path}')
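# The checkpoints written by save_model() above store plain state dicts under
# 'encoder_state', 'decoder_state', 'discriminator_state' plus optimizer
# states, one file per rank. A minimal reload sketch for single-GPU inference;
# the experiment name is illustrative and follows the f'bestmodel_{rank}.pth'
# pattern used above:
states = torch.load('checkpoints/my_exp/bestmodel_0.pth', map_location='cpu')
encoder = Encoder(args)
decoder = WaveNet(args)
encoder.load_state_dict(states['encoder_state'])
decoder.load_state_dict(states['decoder_state'])
encoder.cuda().eval()
decoder.cuda().eval()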
def train(self): if self.verbose: n_correct = 0 n_total = 0 for batch in self.val_loader: exists_target = batch['exists_target'] n_correct += exists_target.sum().item() n_total += len(exists_target) print(f'Val Oracle acc: {n_correct / n_total * 100:.2f}%') n_correct = 0 n_total = 0 for batch in self.test_loader: exists_target = batch['exists_target'] n_correct += exists_target.sum().item() n_total += len(exists_target) print(f'Test Oracle acc: {n_correct / n_total * 100:.2f}%') if self.verbose: loss_meter = LossMeter() best_valid_acc = 0. best_epoch = 0 if 't5' in self.args.backbone: project_name = "VLT5_RefCOCOg" elif 'bart' in self.args.backbone: project_name = "VLBart_RefCOCOg" if self.args.RefCOCO_GT: project_name += '_GT' wandb.init(project=project_name) wandb.run.name = self.args.run_name wandb.config.update(self.args) wandb.watch(self.model) src_dir = Path(__file__).resolve().parent base_path = str(src_dir.parent) src_dir = str(src_dir) wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path) if self.args.distributed: dist.barrier() global_step = 0 for epoch in range(self.args.epochs): if self.start_epoch is not None: epoch += self.start_epoch self.model.train() if self.args.distributed: self.train_loader.sampler.set_epoch(epoch) if self.verbose: pbar = tqdm(total=len(self.train_loader), ncols=120) epoch_results = { 'loss': 0, } n_correct = 0 n_total = 0 for step_i, batch in enumerate(self.train_loader): batch['log_train_accuracy'] = self.args.log_train_accuracy if self.args.fp16 and _use_native_amp: with autocast(): if self.args.distributed: results = self.model.module.train_step(batch) else: results = self.model.train_step(batch) else: if self.args.distributed: results = self.model.module.train_step(batch) else: results = self.model.train_step(batch) loss = results['loss'] if self.args.fp16 and _use_native_amp: self.scaler.scale(loss).backward() elif self.args.fp16 and _use_apex: with amp.scale_loss(loss, self.optim) as scaled_loss: scaled_loss.backward() else: loss.backward() loss = loss.detach() # Update Parameters if self.args.clip_grad_norm > 0: if self.args.fp16 and _use_native_amp: self.scaler.unscale_(self.optim) torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.clip_grad_norm) elif self.args.fp16 and _use_apex: torch.nn.utils.clip_grad_norm_( amp.master_params(self.optim), self.args.clip_grad_norm) else: torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.clip_grad_norm) if self.args.fp16 and _use_native_amp: self.scaler.step(self.optim) self.scaler.update() else: self.optim.step() if self.lr_scheduler: self.lr_scheduler.step() for param in self.model.parameters(): param.grad = None global_step += 1 for k, v in results.items(): if k in epoch_results: epoch_results[k] += v.item() if self.lr_scheduler: if version.parse( torch.__version__) >= version.parse("1.4"): lr = self.lr_scheduler.get_last_lr()[0] else: lr = self.lr_scheduler.get_lr()[0] else: try: lr = self.optim.get_lr()[0] except AttributeError: lr = self.args.lr if self.args.log_train_accuracy: correct = results['correct'] n_correct += sum(correct) n_total += len(correct) if self.verbose: loss_meter.update(loss.item()) # acc_meter.update(results['acc'].item()) desc_str = f'Epoch {epoch} | LR {lr:.6f} | ' desc_str += f'Loss {loss_meter.val:.3f} |' # desc_str += f' Acc {acc_meter.val:.3f} |' if self.args.log_train_accuracy: desc_str += f' Correct {n_correct:.0f}' desc_str += f' (Acc {n_correct/n_total*100:.1f}%)' pbar.set_description(desc_str) pbar.update(1) if 
self.args.distributed: dist.barrier() if self.verbose: pbar.close() if self.args.log_train_accuracy: train_score_dict = {'n_correct': n_correct, 'n_total': n_total} train_score_dict = reduce_dict(train_score_dict, self.args.gpu) # Validation # valid_score_dict = self.evaluate(self.val_loader) # valid_score_dict = reduce_dict(valid_score_dict, self.args.gpu) if self.verbose: if self.args.log_train_accuracy: train_acc = train_score_dict[ 'n_correct'] / train_score_dict['n_total'] * 100 train_n_correct = int(train_score_dict['n_correct']) train_n_total = int(train_score_dict['n_total']) # Validation valid_score_dict = self.evaluate(self.val_loader) valid_acc = valid_score_dict['n_correct'] / valid_score_dict[ 'n_total'] * 100 valid_n_correct = int(valid_score_dict['n_correct']) valid_n_total = int(valid_score_dict['n_total']) if valid_acc > best_valid_acc: best_valid_acc = valid_acc best_epoch = epoch self.save("BEST") log_str = '' if self.args.log_train_accuracy: log_str += f"\nEpoch {epoch}: Train" log_str += f" Acc {train_acc:.2f}% |" log_str += f" # correct {train_n_correct} # total {train_n_total}" log_str += f"\nEpoch {epoch}: Valid" log_str += f" Acc {valid_acc:.2f}% |" log_str += f" # correct {valid_n_correct} # total {valid_n_total}" log_str += f"\nEpoch {best_epoch}: Best Acc {best_valid_acc:.2f}%\n" wandb_log_dict = {} if self.args.log_train_accuracy: wandb_log_dict['Train/Acc'] = train_acc wandb_log_dict['Valid/Acc'] = valid_acc wandb.log(wandb_log_dict, step=epoch) print(log_str) if self.args.distributed: dist.barrier() if self.verbose: self.save("LAST") # Test Set best_path = os.path.join(self.args.output, 'BEST') self.load(best_path) test_score_dict = self.evaluate(self.test_loader) test_acc = test_score_dict['n_correct'] / test_score_dict[ 'n_total'] * 100 test_n_correct = int(test_score_dict['n_correct']) test_n_total = int(test_score_dict['n_total']) wandb_log_dict = {} wandb_log_dict['Test/Acc'] = test_acc wandb.log(wandb_log_dict, step=epoch) log_str = '' log_str += f"\nTest Acc {test_acc:.2f}%" log_str += f"\nTest # correct {test_n_correct} # total {test_n_total}" print(log_str) wandb.log({'finished': True}) if self.args.distributed: dist.barrier()
def train(self): if self.verbose: loss_meter = LossMeter() best_valid = 0. best_epoch = 0 if 't5' in self.args.backbone: if self.args.use_vision: project_name = "VLT5_GQA" else: project_name = "T5_GQA" elif 'bart' in self.args.backbone: if self.args.use_vision: project_name = "VLBart_GQA" else: project_name = "Bart_GQA" wandb.init(project=project_name) wandb.run.name = self.args.run_name wandb.config.update(self.args) wandb.watch(self.model) src_dir = Path(__file__).resolve().parent base_path = str(src_dir.parent) src_dir = str(src_dir) wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path) if self.args.distributed: dist.barrier() # torch.autograd.set_detect_anomaly(True) # print(f'GPU{self.args.gpu} before training starts') global_step = 0 for epoch in range(self.args.epochs): if self.start_epoch is not None: epoch += self.start_epoch self.model.train() if self.args.distributed: self.train_loader.sampler.set_epoch(epoch) if self.verbose: pbar = tqdm(total=len(self.train_loader), ncols=120) epoch_results = { 'loss': 0., } quesid2ans = {} train_acc = 0. train_acc_steps = int(len(self.train_loader) * 0.05) last_acc_step = 0 # print(f'GPU{self.args.gpu} before training loop') for step_i, batch in enumerate(self.train_loader): # print(f'GPU{self.args.gpu} inside training loop') # print(batch) # self.optim.zero_grad() if self.args.fp16 and _use_native_amp: with autocast(): if self.args.distributed: results = self.model.module.train_step(batch) else: results = self.model.train_step(batch) else: if self.args.distributed: results = self.model.module.train_step(batch) else: results = self.model.train_step(batch) loss = results['loss'] # print(f'GPU{self.args.gpu} after loss') if self.args.fp16 and _use_native_amp: self.scaler.scale(loss).backward() elif self.args.fp16 and _use_apex: with amp.scale_loss(loss, self.optim) as scaled_loss: scaled_loss.backward() else: loss.backward() # print(f'GPU{self.args.gpu} after backward') loss = loss.detach() # Update Parameters if self.args.clip_grad_norm > 0: if self.args.fp16 and _use_native_amp: self.scaler.unscale_(self.optim) torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.clip_grad_norm) elif self.args.fp16 and _use_apex: torch.nn.utils.clip_grad_norm_( amp.master_params(self.optim), self.args.clip_grad_norm) else: torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.clip_grad_norm) if self.args.fp16 and _use_native_amp: self.scaler.step(self.optim) self.scaler.update() else: self.optim.step() if self.lr_scheduler: self.lr_scheduler.step() # self.model.zero_grad() for param in self.model.parameters(): param.grad = None global_step += 1 for k, v in results.items(): if k in epoch_results: epoch_results[k] += v.item() if self.lr_scheduler: if version.parse( torch.__version__) >= version.parse("1.4"): lr = self.lr_scheduler.get_last_lr()[0] else: lr = self.lr_scheduler.get_lr()[0] else: try: lr = self.optim.get_lr()[0] except AttributeError: lr = self.args.lr if self.verbose: loss_meter.update(loss.item()) desc_str = f'Epoch {epoch} | LR {lr:.6f}' desc_str += f' | Loss {loss_meter.val:4f}' pbar.set_description(desc_str) pbar.update(1) if self.args.distributed: dist.barrier() if self.verbose: pbar.close() log_str = '' # Validation valid_score = self.evaluate(self.val_loader) * 100. 
if valid_score > best_valid: best_valid = valid_score best_epoch = epoch self.save("BEST") log_str += "\nEpoch %d: Testdev %0.2f" % (epoch, valid_score) log_str += "\nEpoch %d: Best %0.2f\n" % (best_epoch, best_valid) wandb_log_dict = {} wandb_log_dict['Train/Loss'] = epoch_results['loss'] / len( self.train_loader) wandb_log_dict['Testdev/score'] = valid_score wandb.log(wandb_log_dict, step=epoch) print(log_str) if self.args.distributed: dist.barrier() if self.verbose: self.save("LAST") # Test Set best_path = os.path.join(self.args.output, 'BEST') self.load(best_path) dump_path = os.path.join(self.args.output, 'submit.json') self.predict(self.test_loader, dump_path=dump_path) wandb.save(dump_path, base_path=self.args.output) wandb.log({'finished': True}) if self.args.distributed: dist.barrier() exit()
def train(self): if self.verbose: loss_meter = LossMeter() qa_loss_meter = LossMeter() qar_loss_meter = LossMeter() best_valid_Q_AR = 0. best_epoch = 0 if 't5' in self.args.backbone: if self.args.use_vision: project_name = "VLT5_VCR" else: project_name = "T5_VCR" elif 'bart' in self.args.backbone: if self.args.use_vision: project_name = "VLBart_VCR" else: project_name = "Bart_VCR" wandb.init(project=project_name) wandb.run.name = self.args.run_name wandb.config.update(self.args) wandb.watch(self.model) src_dir = Path(__file__).resolve().parent base_path = str(src_dir.parent) src_dir = str(src_dir) wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path) if self.args.distributed: dist.barrier() global_step = 0 for epoch in range(self.args.epochs): if self.start_epoch is not None: epoch += self.start_epoch self.model.train() if self.args.distributed: self.train_loader.sampler.set_epoch(epoch) if self.verbose: pbar = tqdm(total=len(self.train_loader), ncols=200) epoch_results = { 'loss': 0, } Q_A_results = 0 QA_R_results = 0 Q_AR_results = 0 n_total = 0 n_accu = 0 train_loss = 0 train_qa_loss = 0 train_qar_loss = 0 for step_i, batch in enumerate(self.train_loader): batch['log_train_accuracy'] = self.args.log_train_accuracy if self.args.fp16 and _use_native_amp: with autocast(): if self.args.distributed: results = self.model.module.train_step(batch) else: results = self.model.train_step(batch) else: if self.args.distributed: results = self.model.module.train_step(batch) else: results = self.model.train_step(batch) loss = results['loss'] if self.args.gradient_accumulation_steps > 1: loss /= self.args.gradient_accumulation_steps if self.args.fp16 and _use_native_amp: self.scaler.scale(loss).backward() elif self.args.fp16 and _use_apex: with amp.scale_loss(loss, self.optim) as scaled_loss: scaled_loss.backward() else: loss.backward() # Update Parameters if self.args.clip_grad_norm > 0: if self.args.fp16 and _use_native_amp: self.scaler.unscale_(self.optim) torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.clip_grad_norm) elif self.args.fp16 and _use_apex: torch.nn.utils.clip_grad_norm_( amp.master_params(self.optim), self.args.clip_grad_norm) else: torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.clip_grad_norm) update = True if self.args.gradient_accumulation_steps > 1: # if step_i == 0: # update = False # elif step_i % self.args.gradient_accumulation_steps == 0 or step_i == len(self.train_loader) - 1: # update = True # else: # update = False update = ((step_i + 1) % self.args.gradient_accumulation_steps == 0) or (step_i == len(self.train_loader) - 1) n_accu += 1 if update: if self.args.fp16 and _use_native_amp: self.scaler.step(self.optim) self.scaler.update() else: self.optim.step() if self.lr_scheduler: self.lr_scheduler.step() # self.model.zero_grad() for param in self.model.parameters(): param.grad = None global_step += 1 for k, v in results.items(): if k in epoch_results: epoch_results[k] += v.item() if self.lr_scheduler: if version.parse( torch.__version__) >= version.parse("1.4"): lr = self.lr_scheduler.get_last_lr()[0] else: lr = self.lr_scheduler.get_lr()[0] else: try: lr = self.optim.get_lr()[0] except AttributeError: lr = self.args.lr if self.verbose: train_loss += loss.detach().item() train_qa_loss += results['qa_loss'].item( ) / self.args.gradient_accumulation_steps train_qar_loss += results['qar_loss'].item( ) / self.args.gradient_accumulation_steps if self.args.log_train_accuracy: qa_pred = results['qa_pred'] qar_pred = 
results['qar_pred'] a_labels = batch['answer_labels'] r_labels = batch['rationale_labels'] Q_A_correct = a_labels == qa_pred QA_R_correct = r_labels == qar_pred Q_AR_correct = Q_A_correct & QA_R_correct Q_A_results += sum(Q_A_correct) QA_R_results += sum(QA_R_correct) Q_AR_results += sum(Q_AR_correct) n_total += len(qa_pred) if update: if self.args.gradient_accumulation_steps > 1: train_loss *= self.args.gradient_accumulation_steps / n_accu train_qa_loss *= self.args.gradient_accumulation_steps / n_accu train_qar_loss *= self.args.gradient_accumulation_steps / n_accu loss_meter.update(train_loss) qa_loss_meter.update(train_qa_loss) qar_loss_meter.update(train_qar_loss) desc_str = f'Epoch {epoch} | LR {lr:.6f} | Steps {global_step} |' desc_str += f' Loss {loss_meter.val:.3f} |' desc_str += f' QA Loss {qa_loss_meter.val:.3f} |' desc_str += f' QAR Loss {qar_loss_meter.val:.3f} |' train_loss = 0 train_qa_loss = 0 train_qar_loss = 0 n_accu = 0 if self.args.log_train_accuracy: desc_str += f' Q -> A {Q_A_results} ({Q_A_results/n_total*100:.1f}%)' desc_str += f' QA -> R {QA_R_results} ({QA_R_results/n_total*100:.1f}%)' desc_str += f' Q -> AR {Q_AR_results} ({Q_AR_results/n_total*100:.1f}%)' pbar.set_description(desc_str) pbar.update(1) if self.verbose: pbar.close() if self.args.log_train_accuracy: train_score_dict = { 'Q_A': Q_A_results, 'QA_R': QA_R_results, 'Q_AR': Q_AR_results, 'n_total': n_total } train_score_dict = reduce_dict(train_score_dict, self.args.gpu) # Validation valid_score_dict = self.evaluate_val(self.val_loader) if self.verbose: if self.args.log_train_accuracy: train_Q_A = train_score_dict['Q_A'] / train_score_dict[ 'n_total'] * 100 train_QA_R = train_score_dict['QA_R'] / train_score_dict[ 'n_total'] * 100 train_Q_AR = train_score_dict['Q_AR'] / train_score_dict[ 'n_total'] * 100 train_n_total = int(train_score_dict['n_total']) valid_Q_A = valid_score_dict['Q_A'] / valid_score_dict[ 'n_total'] * 100 valid_QA_R = valid_score_dict['QA_R'] / valid_score_dict[ 'n_total'] * 100 valid_Q_AR = valid_score_dict['Q_AR'] / valid_score_dict[ 'n_total'] * 100 valid_n_total = int(valid_score_dict['n_total']) if valid_Q_AR > best_valid_Q_AR: best_valid_Q_AR = valid_Q_AR best_epoch = epoch self.save("BEST") log_str = '' if self.args.log_train_accuracy: log_str += f"\nEpoch {epoch}: Train |" log_str += f" # examples: {train_n_total} |" log_str += f" Q -> A {train_Q_A:.2f}%" log_str += f" QA -> R {train_QA_R:.2f}%" log_str += f" Q -> AR {train_Q_AR:.2f}%" log_str += f"\nEpoch {epoch}: Valid |" log_str += f" # examples: {valid_n_total} |" log_str += f" Q -> A {valid_Q_A:.2f}%" log_str += f" QA -> R {valid_QA_R:.2f}%" log_str += f" Q -> AR {valid_Q_AR:.2f}%" #log_str += "\nEpoch %d: Valid Q -> AR %0.2f" % (epoch, valid_Q_AR) log_str += f"\nBest Epoch {best_epoch}: Q -> AR {best_valid_Q_AR:.2f}%\n" wandb_log_dict = {} # wandb_log_dict['Train/Loss'] = loss_meter.val if self.args.log_train_accuracy: wandb_log_dict['Train/Q_A'] = train_Q_A wandb_log_dict['Train/QA_R'] = train_QA_R wandb_log_dict['Train/Q_AR'] = train_Q_AR wandb_log_dict['Valid/Q_A'] = valid_Q_A wandb_log_dict['Valid/QA_R'] = valid_QA_R wandb_log_dict['Valid/Q_AR'] = valid_Q_AR wandb.log(wandb_log_dict, step=epoch) print(log_str) if self.args.distributed: dist.barrier() if self.verbose: self.save("LAST") # Test Set best_path = os.path.join(self.args.output, 'BEST') self.load(best_path) if self.verbose: dump_path = os.path.join(self.args.output, 'test_submit.csv') print('Dumping test set results at', dump_path) 
self.evaluate_test(self.test_loader, dump_path=dump_path) wandb.save(dump_path, base_path=self.args.output) print('Done!') wandb.log({'finished': True}) if self.args.distributed: dist.barrier() exit()
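# The VCR loop above divides the loss by gradient_accumulation_steps and only
# steps the optimizer every N micro-batches (or on the final batch). The same
# pattern in isolation, as a self-contained generic sketch rather than the
# trainer's exact code:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

accum = 4  # illustrative accumulation factor
model = nn.Linear(8, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
loader = DataLoader(TensorDataset(torch.randn(64, 8),
                                  torch.randint(0, 2, (64,))), batch_size=4)
loss_fn = nn.CrossEntropyLoss()

optimizer.zero_grad()
for step, (x, y) in enumerate(loader):
    loss = loss_fn(model(x), y) / accum   # scale so gradients average over the window
    loss.backward()
    if (step + 1) % accum == 0 or step == len(loader) - 1:
        optimizer.step()                   # update only every `accum` micro-batches
        optimizer.zero_grad()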
def train(self):
    if self.verbose:
        loss_meter = LossMeter()
        best_valid = 0.

        from torch.utils.tensorboard import SummaryWriter
        self.writer = SummaryWriter(log_dir=self.args.log_dir)

        hparam_dict = {}
        for k, v in self.args.__dict__.items():
            if type(v) in [int, float, str, bool, torch.Tensor]:
                hparam_dict[k] = v
        metric_dict = {}
        self.writer.add_hparams(hparam_dict, metric_dict)

    if self.args.distributed:
        dist.barrier()

    self.optim.zero_grad()

    for epoch in range(self.args.epochs):
        self.model.train()
        if self.args.distributed:
            self.train_loader.sampler.set_epoch(epoch)
        if self.verbose:
            pbar = tqdm(total=len(self.train_loader), ncols=150)
            # Confusion-matrix counters: TP, FN, FP, TN.
            # (Renamed from `results`, which was shadowed by the model output below.)
            nlvr_results = np.zeros(4, dtype=int)

        quesid2ans = {}

        for step_i, batch in enumerate(self.train_loader):
            vis_feats = batch['vis_feats'].cuda()
            boxes = batch['boxes'].cuda()
            ques_id = batch['question_ids']
            B = len(ques_id)
            input_ids = batch['word_ids'].cuda()
            input_ids = input_ids.unsqueeze(1).repeat(1, 2, 1).view(B * 2, -1)
            label = batch['labels'].cuda()

            results = self.model(
                input_ids=input_ids,
                visual_feats=vis_feats,
                visual_pos=boxes,
                attention_mask=input_ids > 0,
            )
            logit = results['logit']
            loss = self.mce_loss(logit, label)

            loss.backward()

            update = True
            if self.args.update_freq > 1:
                if step_i == 0:
                    update = False
                elif step_i % self.args.update_freq == 0 or step_i == len(self.train_loader) - 1:
                    update = True
                else:
                    update = False

            if update:
                if not self.args.no_clip_grad:
                    nn.utils.clip_grad_norm_(self.model.parameters(),
                                             self.args.clip_grad_norm)
                self.optim.step()
                self.lr_scheduler.step()
                for param in self.model.parameters():
                    param.grad = None

            try:
                # Was `self.scheduler`, which does not exist on this trainer and
                # silently fell back to args.lr via the AttributeError handler.
                lr = self.lr_scheduler.get_last_lr()[0]
            except AttributeError:
                lr = self.args.lr

            if self.verbose:
                loss_meter.update(loss.item())
                desc_str = f'Epoch {epoch} | LR {lr:.6f} | '
                desc_str += f'Loss {loss_meter.val:4f} |'

                score, predict = logit.max(1)
                predict = predict.cpu().numpy()
                label = label.cpu().numpy()

                for qid, pred in zip(ques_id, predict):
                    quesid2ans[qid] = pred

                nlvr_results[0] += sum((label == 1) & (predict == 1))
                nlvr_results[1] += sum((label == 1) & (predict == 0))
                nlvr_results[2] += sum((label == 0) & (predict == 1))
                nlvr_results[3] += sum((label == 0) & (predict == 0))
                n_total = sum(nlvr_results)

                desc_str += f' TP {nlvr_results[0]} ({nlvr_results[0]/n_total*100:.1f}%)'
                desc_str += f' FN {nlvr_results[1]} ({nlvr_results[1]/n_total*100:.1f}%)'
                desc_str += f' FP {nlvr_results[2]} ({nlvr_results[2]/n_total*100:.1f}%)'
                desc_str += f' TN {nlvr_results[3]} ({nlvr_results[3]/n_total*100:.1f}%)'

                pbar.set_description(desc_str)
                pbar.update(1)

        if self.args.distributed:
            dist.barrier()

        if self.verbose:
            pbar.close()
            score = self.train_loader.evaluator.evaluate(quesid2ans) * 100.
            log_str = "\nEpoch %d: Train %0.2f" % (epoch, score)

            if not self.args.dry:
                self.writer.add_scalar(f'NLVR/Train/score', score, epoch)

            # Validation
            valid_score = self.evaluate(self.val_loader) * 100.
            if valid_score > best_valid:
                best_valid = valid_score
                self.save("BEST")

            log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_score)
            log_str += "\nEpoch %d: Best %0.2f\n" % (epoch, best_valid)

            if not self.args.dry:
                self.writer.add_scalar(f'NLVR/Valid/score', valid_score, epoch)

            print(log_str)
            self.logger.info(log_str)

        if self.args.distributed:
            dist.barrier()

    if self.verbose:
        self.save("LAST")
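# The TP/FN/FP/TN counters kept in the NLVR loop above are enough to derive
# the usual binary-classification metrics. A small illustrative helper, not
# part of the original trainer:
def confusion_metrics(tp, fn, fp, tn):
    """Accuracy, precision, and recall from the four NLVR counters."""
    total = tp + fn + fp + tn
    accuracy = (tp + tn) / total if total else 0.
    precision = tp / (tp + fp) if (tp + fp) else 0.
    recall = tp / (tp + fn) if (tp + fn) else 0.
    return accuracy, precision, recall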
def train(self): if self.verbose: loss_meter = LossMeter() # best_eval_loss = 9595. quesid2ans = {} best_valid = 0. best_epoch = 0 if 't5' in self.args.backbone: project_name = "VLT5_NLVR" elif 'bart' in self.args.backbone: project_name = "VLBART_NLVR" wandb.init(project=project_name) wandb.run.name = self.args.run_name wandb.config.update(self.args) wandb.watch(self.model) src_dir = Path(__file__).resolve().parent base_path = str(src_dir.parent) src_dir = str(src_dir) wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path) if self.args.distributed: dist.barrier() global_step = 0 for epoch in range(self.args.epochs): if self.start_epoch is not None: epoch += self.start_epoch self.model.train() if self.args.distributed: self.train_loader.sampler.set_epoch(epoch) if self.verbose: pbar = tqdm(total=len(self.train_loader), ncols=150) nlvr_results = np.zeros(4, dtype=int) epoch_results = { 'loss': 0, } quesid2ans = {} train_acc = 0. train_acc_steps = int(len(self.train_loader) * 0.05) last_acc_step = 0 for step_i, batch in enumerate(self.train_loader): self.optim.zero_grad() if self.args.fp16 and _use_native_amp: with autocast(): if self.args.distributed: results = self.model.module.train_step(batch) else: results = self.model.train_step(batch) else: if self.args.distributed: results = self.model.module.train_step(batch) else: results = self.model.train_step(batch) loss = results['loss'] # print(f'GPU{self.args.gpu} after loss') if self.args.fp16 and _use_native_amp: self.scaler.scale(loss).backward() elif self.args.fp16 and _use_apex: with amp.scale_loss(loss, self.optim) as scaled_loss: scaled_loss.backward() else: loss.backward() # print(f'GPU{self.args.gpu} after backward') loss = loss.detach() # Update Parameters if self.args.clip_grad_norm > 0: if self.args.fp16 and _use_native_amp: self.scaler.unscale_(self.optim) torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.clip_grad_norm) elif self.args.fp16 and _use_apex: torch.nn.utils.clip_grad_norm_(amp.master_params( self.optim), self.args.clip_grad_norm) else: torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.clip_grad_norm) if self.args.fp16 and _use_native_amp: self.scaler.step(self.optim) self.scaler.update() else: self.optim.step() if self.lr_scheduler: self.lr_scheduler.step() # self.model.zero_grad() for param in self.model.parameters(): param.grad = None global_step += 1 for k, v in results.items(): if k in epoch_results: epoch_results[k] += v.item() if self.lr_scheduler: if version.parse(torch.__version__) >= version.parse("1.4"): lr = self.lr_scheduler.get_last_lr()[0] else: lr = self.lr_scheduler.get_lr()[0] else: try: lr = self.optim.get_lr()[0] except AttributeError: lr = self.args.lr if self.verbose: loss_meter.update(loss.item()) desc_str = f'Epoch {epoch} | LR {lr:.6f} | ' desc_str += f'Loss {loss_meter.val:4f} |' pred_ans = results['pred_ans_id'] ques_ids = batch['question_ids'] for qid, ans in zip(ques_ids, pred_ans): quesid2ans[qid] = ans label = batch['labels'].cpu().numpy() predict = results['pred_ans_id'] nlvr_results[0] += sum((label == 1) & (predict == 1)) nlvr_results[1] += sum((label == 1) & (predict == 0)) nlvr_results[2] += sum((label == 0) & (predict == 1)) nlvr_results[3] += sum((label == 0) & (predict == 0)) n_total = sum(nlvr_results) desc_str += f' TP {nlvr_results[0]} ({nlvr_results[0]/n_total*100:.1f}%)' desc_str += f' FN {nlvr_results[1]} ({nlvr_results[1]/n_total*100:.1f}%)' desc_str += f' FP {nlvr_results[2]} ({nlvr_results[2]/n_total*100:.1f}%)' desc_str += f' 
TN {nlvr_results[3]} ({nlvr_results[3]/n_total*100:.1f}%)' pbar.set_description(desc_str) pbar.update(1) if self.args.distributed: dist.barrier() if self.verbose: pbar.close() log_str = '' # train_score_dict = self.train_loader.evaluator.evaluate(quesid2ans) # train_acc = train_score_dict['accuracy'] * 100. # train_cons = train_score_dict['consistency'] * 100. train_acc = self.train_loader.evaluator.evaluate_train(quesid2ans) * 100. train_score = train_acc log_str += "\nEpoch %d: Train %0.2f" % (epoch, train_score) # Validation valid_score_dict = self.evaluate(self.val_loader) valid_acc = valid_score_dict['accuracy'] * 100. # valid_cons = valid_score_dict['consistency'] * 100. valid_score = valid_acc if valid_score > best_valid: best_valid = valid_score best_epoch = epoch self.save("BEST") log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_score) log_str += "\nEpoch %d: Best %0.2f\n" % (best_epoch, best_valid) wandb_log_dict = {} # wandb_log_dict['Train/Loss'] = loss_meter.val # wandb_log_dict['Train/score'] = score # wandb_log_dict['Valid/score'] = valid_score # for score_name, score in train_score_dict.items(): # wandb_log_dict[f'Train/{score_name}'] = score * 100. wandb_log_dict['Train/accuracy'] = train_acc for score_name, score in valid_score_dict.items(): wandb_log_dict[f'Valid/{score_name}'] = score * 100. wandb.log(wandb_log_dict, step=epoch) print(log_str) if self.args.distributed: dist.barrier() if self.verbose: self.save("LAST") # Test Set best_path = os.path.join(self.args.output, 'BEST') self.load(best_path) log_str = 'Test set results\n' dump_path = os.path.join(self.args.output, 'submit.csv') test_score_dict = self.evaluate(self.test_loader, dump_path=dump_path) wandb.save(dump_path, base_path=self.args.output) wandb_log_dict = {} for score_name, score in test_score_dict.items(): wandb_log_dict[f'Test/{score_name}'] = score * 100. wandb.log(wandb_log_dict, step=epoch) from pprint import pformat log_str += pformat(test_score_dict) print(log_str) wandb.log({'finished': True})
def train(self): if self.verbose: vqa_loss_meter = LossMeter() refcoco_loss_meter = LossMeter() # best_eval_loss = 9595. quesid2ans = {} best_vqa_valid = 0. best_vqa_epoch = 0 # gqa best_gqa_valid = 0 best_gqa_epoch = 0 # nlvr best_nlvr_valid = 0 best_nlvr_epoch = 0 # vcr best_valid_Q_AR = 0 best_vcr_epoch = 0 # refcoco best_refcoco_valid = 0 best_refcoco_epoch = 0 # caption best_caption_valid = 0 best_caption_epoch = 0 # mmt best_mmt_valid = 0 best_mmt_epoch = 0 assert 't5' in self.args.backbone self.setup_wandb() if self.args.distributed: dist.barrier() global_step = 0 for epoch in range(self.args.epochs): if self.start_epoch is not None: epoch += self.start_epoch self.model.train() if self.args.distributed: self.train_loader.set_epoch(epoch) if self.verbose: pbar = tqdm(total=len(self.train_loader), ncols=250) epoch_results = { 'loss': 0., } task_counter = { 'vqa': 0, 'gqa': 0, 'nlvr': 0, 'refcoco': 0, 'vcr': 0, 'caption': 0, 'mmt': 0, } # vqa quesid2ans = {} train_acc = 0. # train_acc_steps = int(len(self.train_loader) * 0.05) # last_acc_step = 0 # refcoco n_correct = 0 n_total = 0 for step_i, batch in enumerate(self.train_loader): # print(f'GPU{self.args.gpu} inside training loop') # print(batch) task = batch['task'] # if self.verbose: # print('task', task) task_counter[task] += 1 batch['log_train_accuracy'] = self.args.log_train_accuracy # self.optim.zero_grad() if self.args.fp16 and _use_native_amp: with autocast(): if self.args.distributed: results = self.model.module.train_step(batch) else: results = self.model.train_step(batch) else: if self.args.distributed: results = self.model.module.train_step(batch) else: results = self.model.train_step(batch) loss = results['loss'] # print(f'GPU{self.args.gpu} after loss') if self.args.fp16 and _use_native_amp: self.scaler.scale(loss).backward() elif self.args.fp16 and _use_apex: with amp.scale_loss(loss, self.optim) as scaled_loss: scaled_loss.backward() else: loss.backward() # print(f'GPU{self.args.gpu} after backward') loss = loss.detach() # Update Parameters if self.args.clip_grad_norm > 0: if self.args.fp16 and _use_native_amp: self.scaler.unscale_(self.optim) torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.clip_grad_norm) elif self.args.fp16 and _use_apex: torch.nn.utils.clip_grad_norm_( amp.master_params(self.optim), self.args.clip_grad_norm) else: torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.clip_grad_norm) if self.args.fp16 and _use_native_amp: self.scaler.step(self.optim) self.scaler.update() else: self.optim.step() if self.lr_scheduler: self.lr_scheduler.step() for param in self.model.parameters(): param.grad = None global_step += 1 for k, v in results.items(): if k in epoch_results: epoch_results[k] += v.item() if self.lr_scheduler: if version.parse( torch.__version__) >= version.parse("1.4"): lr = self.lr_scheduler.get_last_lr()[0] else: lr = self.lr_scheduler.get_lr()[0] else: try: lr = self.optim.get_lr()[0] except AttributeError: lr = self.args.lr # self.train_step_post_hook(result) if self.args.log_train_accuracy and task == 'refcoco': correct = results['correct'] n_correct += sum(correct) n_total += len(correct) if self.verbose: if task == 'vqa': vqa_loss_meter.update(loss.item()) elif task == 'refcoco': refcoco_loss_meter.update(loss.item()) desc_str = f'Epoch {epoch} | LR {lr:.6f}' desc_str += f" |" if 'vqa' in self.args.tasks: desc_str += f" VQA {task_counter['vqa']}" if 'gqa' in self.args.tasks: desc_str += f" GQA {task_counter['gqa']}" if 'nlvr' in self.args.tasks: desc_str += 
f" NLVR {task_counter['nlvr']}" if 'vcr' in self.args.tasks: desc_str += f" VCR {task_counter['vcr']}" if 'refcoco' in self.args.tasks: desc_str += f" RefCOCOg {task_counter['refcoco']}" if 'caption' in self.args.tasks: desc_str += f" COCO {task_counter['caption']}" if 'mmt' in self.args.tasks: desc_str += f" MMT {task_counter['mmt']}" if len(vqa_loss_meter) > 0: desc_str += f' | VQA Loss {vqa_loss_meter.val:4f}' if len(refcoco_loss_meter) > 0: desc_str += f' | RefCOCOg Loss {refcoco_loss_meter.val:.3f}' if self.args.log_train_accuracy and n_total > 0: desc_str += f' | RefCOCOg Acc' desc_str += f' Correct {n_correct:.0f}' desc_str += f' (Acc {n_correct/n_total*100:.1f}%)' pbar.set_description(desc_str) pbar.update(1) if self.args.distributed: dist.barrier() if self.verbose: pbar.close() self.save("Epoch%02d" % (epoch + 1)) if self.args.log_train_accuracy: train_score_dict = {'n_correct': n_correct, 'n_total': n_total} train_score_dict = reduce_dict(train_score_dict, self.args.gpu) if self.verbose: # Validation log_str = '' wandb_log_dict = {} if 'vqa' in self.args.tasks: # VQA vqa_val_loader = self.val_loader['vqa'] score_dict = self.vqa_evaluate(vqa_val_loader) valid_score = score_dict['topk_score'] * 100. valid_score_raw = score_dict['overall'] if valid_score_raw > best_vqa_valid or epoch == 0: best_vqa_valid = valid_score_raw best_vqa_epoch = epoch # self.save("VQA_BEST") log_str += f"VQA" log_str += "\nEpoch %d: Valid Raw %0.2f Topk %0.2f" % ( epoch, valid_score_raw, valid_score) log_str += "\nEpoch %d: Best Raw %0.2f\n" % ( best_vqa_epoch, best_vqa_valid) wandb_log_dict['VQA/Valid/score'] = valid_score wandb_log_dict['VQA/Valid/raw_score'] = score_dict[ 'overall'] if 'gqa' in self.args.tasks: # GQA gqa_val_loader = self.val_loader['gqa'] valid_score = self.gqa_evaluate(gqa_val_loader) * 100 if valid_score > best_gqa_valid or epoch == 0: best_gqa_valid = valid_score best_gqa_epoch = epoch wandb_log_dict['GQA/Valid/Acc'] = valid_score log_str += f"GQA" log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_score) log_str += "\nEpoch %d: Best %0.2f\n" % (best_gqa_epoch, best_gqa_valid) if 'nlvr' in self.args.tasks: # NLVR nlvr_val_loader = self.val_loader['nlvr'] valid_score_dict = self.nlvr_evaluate(nlvr_val_loader) valid_acc = valid_score_dict['accuracy'] * 100. 
if valid_acc > best_nlvr_valid or epoch == 0: best_nlvr_valid = valid_acc best_nlvr_epoch = epoch wandb_log_dict['NLVR/Valid/Acc'] = valid_acc log_str += f"NLVR" log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_acc) log_str += "\nEpoch %d: Best %0.2f\n" % (best_nlvr_epoch, best_nlvr_valid) if 'vcr' in self.args.tasks: # VCR vcr_val_loader = self.val_loader['vcr'] valid_score_dict = self.vcr_evaluate(vcr_val_loader) valid_Q_A = valid_score_dict['Q_A'] / valid_score_dict[ 'n_total'] * 100 valid_QA_R = valid_score_dict['QA_R'] / valid_score_dict[ 'n_total'] * 100 valid_Q_AR = valid_score_dict['Q_AR'] / valid_score_dict[ 'n_total'] * 100 valid_n_total = int(valid_score_dict['n_total']) if valid_Q_AR > best_valid_Q_AR or epoch == 0: best_valid_Q_AR = valid_Q_AR best_vcr_epoch = epoch wandb_log_dict['VCR/Valid/Q_A'] = valid_Q_A wandb_log_dict['VCR/Valid/QA_R'] = valid_QA_R wandb_log_dict['VCR/Valid/Q_AR'] = valid_Q_AR log_str += f"VCR" log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_Q_AR) log_str += "\nEpoch %d: Best %0.2f\n" % (best_vcr_epoch, best_valid_Q_AR) if 'refcoco' in self.args.tasks: # RefCOCO refcoco_val_loader = self.val_loader['refcoco'] if self.args.log_train_accuracy: train_acc = train_score_dict[ 'n_correct'] / train_score_dict['n_total'] * 100 train_n_correct = int(train_score_dict['n_correct']) train_n_total = int(train_score_dict['n_total']) valid_score_dict = self.refcoco_evaluate( refcoco_val_loader) valid_acc = valid_score_dict[ 'n_correct'] / valid_score_dict['n_total'] * 100 valid_n_correct = int(valid_score_dict['n_correct']) valid_n_total = int(valid_score_dict['n_total']) if valid_acc > best_refcoco_valid or epoch == 0: best_refcoco_valid = valid_acc best_refcoco_epoch = epoch if self.args.log_train_accuracy: wandb_log_dict['RefCOCO/Train/Acc'] = train_acc wandb_log_dict['RefCOCO/Valid/Acc'] = valid_acc log_str += f"RefCOCOg" if self.args.log_train_accuracy: log_str += f"\nEpoch {epoch}: Train" log_str += f" Acc {train_acc:.2f}% |" log_str += f" # correct {train_n_correct} # total {train_n_total}" log_str += f"\nEpoch {epoch}: Valid" log_str += f" Acc {valid_acc:.2f}% |" log_str += f" # correct {valid_n_correct} # total {valid_n_total}" log_str += f"\nEpoch {best_refcoco_epoch}: Best Acc {best_refcoco_valid:.2f}%\n" if 'caption' in self.args.tasks: # COCO Caption caption_val_loader = self.val_loader['caption'] valid_results = self.caption_evaluate(caption_val_loader) valid_score = valid_results['CIDEr'] * 100 if valid_score > best_caption_valid or epoch == 0: best_caption_valid = valid_score best_caption_epoch = epoch for score_name, score in valid_results.items(): wandb_log_dict[ f'Caption/Valid/{score_name}'] = score * 100 log_str += f"COCO Caption" log_str += "\nEpoch %d: Valid CIDEr %0.2f" % (epoch, valid_score) log_str += "\nEpoch %d: Best %0.2f\n" % ( best_caption_epoch, best_caption_valid) if 'mmt' in self.args.tasks: # MMT mmt_val_loader = self.val_loader['mmt'] valid_results = self.mmt_evaluate(mmt_val_loader) valid_score = valid_results['BLEU'] if valid_score > best_mmt_valid: best_mmt_valid = valid_score best_mmt_epoch = epoch for score_name, score in valid_results.items(): wandb_log_dict[f'MMT/Valid/{score_name}'] = score log_str += f"Multi30K En-De" log_str += "\nEpoch %d: Valid BLEU %0.2f" % (epoch, valid_score) log_str += "\nEpoch %d: Best %0.2f\n" % (best_mmt_epoch, best_mmt_valid) wandb.log(wandb_log_dict, step=epoch) print(log_str) if self.args.distributed: dist.barrier() # Test Set if self.verbose: self.save("LAST") log_str = '' 
wandb_log_dict = {} if 'vqa' in self.args.tasks: # VQA vqa_test_loader = self.test_loader['vqa'] evaluator = vqa_test_loader.evaluator dump_path = os.path.join(self.args.output, 'karpathy_test_predict.json') quesid2ans = self.vqa_predict(vqa_test_loader, dump_path) wandb.save(dump_path, base_path=self.args.output) acc_dict_all = evaluator.evaluate_raw(quesid2ans) acc_dict_answerable = evaluator.evaluate_raw( quesid2ans, is_topk_optimal=True) acc_dict_unanswerable = evaluator.evaluate_raw( quesid2ans, is_topk_optimal=False) wandb_log_dict['VQA/Test/overall'] = acc_dict_all['overall'] wandb_log_dict['VQA/Test/topk_optimal'] = acc_dict_answerable[ 'overall'] wandb_log_dict[ 'VQA/Test/topk_not_optimal'] = acc_dict_unanswerable[ 'overall'] vqa_submit_test_loader = self.test_loader['vqa_submit'] dump_path = os.path.join(self.args.output, 'vqa_submit.json') self.vqa_predict(vqa_submit_test_loader, dump_path=dump_path) wandb.save(dump_path, base_path=self.args.output) if 'nlvr' in self.args.tasks: # NLVR nlvr_test_loader = self.test_loader['nlvr'] dump_path = os.path.join(self.args.output, 'nlvr_submit.csv') test_score_dict = self.nlvr_evaluate(nlvr_test_loader, dump_path=dump_path) wandb.save(dump_path, base_path=self.args.output) for score_name, score in test_score_dict.items(): wandb_log_dict[f'NLVR/Test/{score_name}'] = score * 100. if 'refcoco' in self.args.tasks: # RefCOCO refcoco_test_loader = self.test_loader['refcoco'] test_score_dict = self.refcoco_evaluate(refcoco_test_loader) test_acc = test_score_dict['n_correct'] / test_score_dict[ 'n_total'] * 100 test_n_correct = int(test_score_dict['n_correct']) test_n_total = int(test_score_dict['n_total']) wandb_log_dict['RefCOCO/test/Acc'] = test_acc log_str = 'RefCOCOg' log_str += f"\nTest Acc {test_acc:.2f}%" log_str += f"\nTest # correct {test_n_correct} # total {test_n_total}" if 'caption' in self.args.tasks: # COCO Caption caption_test_loader = self.test_loader['caption'] test_results = self.caption_evaluate(caption_test_loader) for score_name, score in test_results.items(): wandb_log_dict[f'Caption/Test/{score_name}'] = score if 'mmt' in self.args.tasks: # MMT mmt_test2016_loader = self.test_loader['mmt_test2016'] mmt_test2017_loader = self.test_loader['mmt_test2017'] mmt_test2018_loader = self.test_loader['mmt_test2018'] for loader in [ mmt_test2016_loader, mmt_test2017_loader, mmt_test2018_loader ]: split = loader.dataset.source dump_path = os.path.join(self.args.output, f'submit_{split}_raw.txt') test_results = self.mmt_evaluate(loader, dump_path=dump_path) for score_name, score in test_results.items(): wandb_log_dict[f'MMT/{split}/{score_name}'] = score log_str += f'{split} set results\n' log_str += pformat(test_results) print(log_str) wandb.log(wandb_log_dict, step=self.args.epochs) wandb.log({'finished': True}) if self.args.distributed: dist.barrier() exit()
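# The multi-task loop above folds per-GPU counters (e.g. the RefCOCOg n_correct / n_total pair)
# together with `reduce_dict` before turning them into accuracies. That helper is defined elsewhere
# in the repo; the function below is only a hedged sketch of the usual all_reduce-based
# implementation, not the project's exact code.
import torch
import torch.distributed as dist


def reduce_dict_sketch(input_dict, average=False):
    """Sum (or average) every value of a {name: scalar} dict across workers."""
    if not dist.is_available() or not dist.is_initialized() or dist.get_world_size() < 2:
        return input_dict
    with torch.no_grad():
        names = sorted(input_dict.keys())  # identical ordering on every rank
        values = torch.tensor([float(input_dict[k]) for k in names], device='cuda')
        dist.all_reduce(values)            # element-wise sum over all ranks
        if average:
            values /= dist.get_world_size()
        return {k: v.item() for k, v in zip(names, values)}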
def train(self): if self.verbose: loss_meter = LossMeter() best_valid = 0. best_epoch = 0 if not self.wandb_initialized: if 't5' in self.args.backbone: project_name = "VLT5_COCOCaption" elif 'bart' in self.args.backbone: project_name = "VLBart_COCOCaption" wandb.init(project=project_name) wandb.run.name = self.args.run_name wandb.config.update(self.args) wandb.watch(self.model) src_dir = Path(__file__).resolve().parent base_path = str(src_dir.parent) src_dir = str(src_dir) wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path) self.wandb_initialized = True if self.args.distributed: dist.barrier() global_step = 0 epochs = self.args.epochs for epoch in range(epochs): if self.start_epoch is not None: epoch += self.start_epoch self.model.train() if self.args.distributed: self.train_loader.sampler.set_epoch(epoch) if self.verbose: pbar = tqdm(total=len(self.train_loader), ncols=120) epoch_results = { 'loss': 0., } for step_i, batch in enumerate(self.train_loader): if self.args.fp16 and _use_native_amp: with autocast(): if self.args.distributed: results = self.model.module.train_step(batch) else: results = self.model.train_step(batch) else: if self.args.distributed: results = self.model.module.train_step(batch) else: results = self.model.train_step(batch) loss = results['loss'] if self.args.fp16 and _use_native_amp: self.scaler.scale(loss).backward() elif self.args.fp16 and _use_apex: with amp.scale_loss(loss, self.optim) as scaled_loss: scaled_loss.backward() else: loss.backward() loss = loss.detach() # Update Parameters if self.args.clip_grad_norm > 0: if self.args.fp16 and _use_native_amp: self.scaler.unscale_(self.optim) torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.clip_grad_norm) elif self.args.fp16 and _use_apex: torch.nn.utils.clip_grad_norm_( amp.master_params(self.optim), self.args.clip_grad_norm) else: torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.clip_grad_norm) update = True if self.args.gradient_accumulation_steps > 1: if step_i == 0: update = False elif step_i % self.args.gradient_accumulation_steps == 0 or step_i == len( self.train_loader) - 1: update = True else: update = False if update: if self.args.fp16 and _use_native_amp: self.scaler.step(self.optim) self.scaler.update() else: self.optim.step() if self.lr_scheduler: self.lr_scheduler.step() # self.model.zero_grad() for param in self.model.parameters(): param.grad = None global_step += 1 for k, v in results.items(): if k in epoch_results: epoch_results[k] += v.item() if self.lr_scheduler: if version.parse( torch.__version__) >= version.parse("1.4"): lr = self.lr_scheduler.get_last_lr()[0] else: lr = self.lr_scheduler.get_lr()[0] else: try: lr = self.optim.get_lr()[0] except AttributeError: lr = self.args.lr if self.verbose: loss_meter.update(loss.item()) desc_str = f'Epoch {epoch} | LR {lr:.6f} | Steps {global_step}' desc_str += f' | Loss {loss_meter.val:4f}' pbar.set_description(desc_str) pbar.update(1) if self.args.distributed: dist.barrier() if self.verbose: pbar.close() # format ex) # {'Bleu_1': 0.9999999997500004, # 'Bleu_2': 0.5773502690332603, # 'Bleu_3': 4.3679023223468616e-06, # 'Bleu_4': 1.4287202142987477e-08, # 'CIDEr': 3.333333333333333, # 'METEOR': 0.43354749322305886, # 'ROUGE_L': 0.75, # 'SPICE': 0.6666666666666666} # Validation valid_results = self.evaluate(self.val_loader) if self.verbose: valid_score = valid_results['CIDEr'] if valid_score > best_valid or epoch == 0: best_valid = valid_score best_epoch = epoch self.save("BEST") log_str = '' log_str += 
pformat(valid_results) log_str += "\nEpoch %d: Valid CIDEr %0.4f" % (epoch, valid_score) log_str += "\nEpoch %d: Best CIDEr %0.4f\n" % (best_epoch, best_valid) wandb_log_dict = {} wandb_log_dict['Train/Loss'] = epoch_results['loss'] / len( self.train_loader) for score_name, score in valid_results.items(): wandb_log_dict[f'Valid/{score_name}'] = score wandb_log_dict[f'Valid/best_epoch'] = best_epoch wandb.log(wandb_log_dict, step=epoch) print(log_str) if self.args.distributed: dist.barrier() if self.verbose: self.save("LAST") # Test Set best_path = os.path.join(self.args.output, 'BEST') self.load(best_path) if self.verbose: wandb.save(best_path, base_path=self.args.output) print(f'\nUploaded checkpoint {best_epoch}', best_path) test_results = self.evaluate(self.test_loader) if self.verbose: wandb_log_dict = {} for score_name, score in test_results.items(): wandb_log_dict[f'Test/{score_name}'] = score wandb.log(wandb_log_dict, step=epoch) log_str = 'Test set results\n' log_str += pformat(test_results) print(log_str) if self.args.distributed: dist.barrier()
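# The caption trainer above only steps the optimizer on selected iterations when
# `gradient_accumulation_steps > 1`. Its predicate is restated here as a small helper so the intent
# is easier to see; this mirrors the logic in the loop above and adds no new behaviour.
def should_step(step_i, num_steps, accum_steps):
    """True on iterations where the optimizer should be stepped."""
    if accum_steps <= 1:
        return True
    if step_i == 0:
        return False
    return step_i % accum_steps == 0 or step_i == num_steps - 1

# usage sketch:
# update = should_step(step_i, len(self.train_loader), self.args.gradient_accumulation_steps)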
def train(self): LOSSES_NAME = self.args.LOSSES_NAME task_dict = { 'Mask_LM': 'word_mask', 'Matched': 'matched', 'Mask_Obj': 'vis_mask', 'Mask_Attr': 'vis_mask', 'Mask_Feat': 'vis_mask', 'QA': 'qa' } if self.args.dry: results = self.evaluate_epoch(epoch=0) self.optim.zero_grad() if self.verbose: loss_meters = [LossMeter() for _ in range(len(LOSSES_NAME))] best_eval_loss = 9595. from torch.utils.tensorboard import SummaryWriter self.writer = SummaryWriter(log_dir=self.args.log_dir) print('logging at', str(self.args.log_dir)) self.logger.info('logging at' + str(self.args.log_dir)) hparam_dict = {} for k, v in self.args.__dict__.items(): if type(v) in [int, float, str, bool, torch.Tensor]: hparam_dict[k] = v metric_dict = {} self.writer.add_hparams(hparam_dict, metric_dict) dist.barrier() n_update = 0 global_step = 0 for epoch in range(self.args.epochs): if self.start_epoch is not None: epoch += self.start_epoch if self.args.distributed: self.train_loader.sampler.set_epoch(epoch) # Train self.model.train() loss_counts = [0 for _ in range(len(LOSSES_NAME))] if self.verbose: pbar = tqdm(total=len(self.train_loader), ncols=240) epoch_results = { 'lm_loss': 0, 'vis_loss': 0, 'matched_loss': 0, 'qa_loss': 0, 'obj_loss': 0, 'feat_loss': 0, 'attr_loss': 0, } for k in list(epoch_results.keys()): if k[-4:] == 'loss': epoch_results[f'{k}_count'] = 0 if self.args.task_qa: uid2ans = {} for step_i, batch in enumerate(self.train_loader): # task = random.choice(self.args.MASK_MODALITY) task_i = step_i % len(self.args.MASK_MODALITY) task = self.args.MASK_MODALITY[task_i] # with torch.autograd.set_detect_anomaly(True): results = self.forward(batch, task) if self.args.fp16 and _use_native_amp: with autocast(): results = self.model(batch, task) else: results = self.model(batch, task) if task == 'vis_mask': if 'Mask_Obj' in LOSSES_NAME: epoch_results['obj_loss_count'] += 1 if 'Mask_Feat' in LOSSES_NAME: epoch_results['feat_loss_count'] += 1 if 'Mask_Attr' in LOSSES_NAME: epoch_results['attr_loss_count'] += 1 epoch_results['vis_loss_count'] += 1 elif task == 'word_mask': epoch_results['lm_loss_count'] += 1 elif task == 'matched': epoch_results['matched_loss_count'] += 1 if self.args.task_qa: epoch_results['qa_loss_count'] += 1 qa_pred = results['qa_pred'] for uid, ans_id in zip(batch['uid'], qa_pred.cpu().numpy()): ans = self.train_loader.dataset.answer_table.id2ans( ans_id) uid2ans[uid] = ans loss = results['total_loss'] #===== Update =====# if self.args.fp16 and _use_native_amp: self.scaler.scale(loss).backward() elif self.args.fp16 and _use_apex: with amp.scale_loss(loss, self.optim) as scaled_loss: scaled_loss.backward() else: loss.backward() loss = loss.detach() # Update Parameters if self.args.clip_grad_norm > 0: if self.args.fp16 and _use_native_amp: self.scaler.unscale_(self.optim) torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.clip_grad_norm) elif self.args.fp16 and _use_apex: torch.nn.utils.clip_grad_norm_( amp.master_params(self.optim), self.args.clip_grad_norm) else: torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.clip_grad_norm) if self.args.fp16 and _use_native_amp: self.scaler.step(self.optim) self.scaler.update() else: self.optim.step() if self.lr_scheduler: self.lr_scheduler.step() for param in self.model.parameters(): param.grad = None global_step += 1 #====================# try: lr = self.scheduler.get_last_lr()[0] except AttributeError: lr = self.args.lr if self.verbose: desc_str = f'Epoch {epoch} | LR {lr:.6f} | ' if self.args.word_mask_predict: desc_str 
+= f'Word Mask: Uniform (MP) | ' elif self.args.word_mask_rate > 0: desc_str += f'Word Mask: {self.args.word_mask_rate:.2f} | ' if self.args.vis_mask_predict: desc_str += f'Vis Mask: Uniform (MP) |' else: desc_str += f'Vis Mask: {self.args.obj_mask_rate:.2f} |' if self.args.task_qa: loss_meter = loss_meters[-1] loss_meter.update(results['qa_loss'].item()) loss_counts[-1] += 1 for i, (loss_name, loss_meter) in enumerate( zip(LOSSES_NAME, loss_meters)): if task_dict[loss_name] == task: if task == 'vis_mask': if loss_name == 'Mask_Obj': loss_meter.update( results['obj_loss'].item()) elif loss_name == 'Mask_Attr': loss_meter.update( results['attr_loss'].item()) elif loss_name == 'Mask_Feat': loss_meter.update( results['feat_loss'].item()) elif task == 'word_mask': loss_meter.update(results['lm_loss'].item()) elif task == 'matched': loss_meter.update( results['matched_loss'].item()) # elif task == 'qa': # loss_meter.update(results['qa_loss'].item()) loss_counts[i] += 1 if len(loss_meter) > 0: loss_count = loss_counts[i] if loss_name in [ 'Mask_LM', 'Matched', 'Mask_Obj', 'Mask_Attr', 'Mask_Feat', 'QA' ]: desc_str += f' {loss_name} ({loss_count}) {loss_meter.val:.3f}' else: desc_str += f' {loss_name} {loss_meter.val:.3f}' if step_i % 10 == 0: self.writer.add_scalar( f'Train_steps/{loss_name}', loss_meter.val, global_step) # if update: n_update += 1 desc_str += f' | Total Update: {n_update}' pbar.set_description(desc_str) pbar.update(1) if self.verbose: pbar.close() dist.barrier() results = reduce_dict(epoch_results, self.args.gpu) if self.args.gpu == 0: total_loss = results['lm_loss'] + results[ 'vis_loss'] + results['matched_loss'] + results['qa_loss'] total_count = results['lm_loss_count'] + results[ 'vis_loss_count'] + results['matched_loss_count'] # + results['qa_loss_count'] avg_train_loss = total_loss / total_count losses_str = f"Train Loss: {avg_train_loss:.4f}\n" for name, loss in results.items(): if name[-4:] == 'loss': loss_count = int(results[name + '_count']) if loss_count > 0: avg_loss = loss / loss_count if name == 'lm_loss': name = 'Mask_LM' elif name == 'matched_loss': name = 'Matched' elif name == 'obj_loss': name = 'Mask_Obj' elif name == 'attr_loss': name = 'Mask_Attr' elif name == 'feat_loss': name = 'Mask_Feat' elif name == 'qa_loss': name = 'QA' losses_str += f"{name} ({loss_count}): {avg_loss:.4f} " self.writer.add_scalar(f'Train Loss/{name}', avg_loss, epoch) losses_str += '\n' print(losses_str) self.logger.info(losses_str) if self.args.task_qa: dset2score, dset2cnt, score, cnt = self.train_loader.dataset.evaluator.evaluate( uid2ans) dset2score = reduce_dict(dset2score, self.args.gpu) dset2cnt = reduce_dict(dset2cnt, self.args.gpu) score_cnt_dict = reduce_dict({ 'score': score, 'cnt': cnt }, self.args.gpu) if self.args.gpu == 0: score = score_cnt_dict['score'] cnt = score_cnt_dict['cnt'] accu = score / cnt dset2accu = {} for dset in dset2cnt: dset2accu[dset] = dset2score[dset] / dset2cnt[dset] accu_str = "Overall Accu %0.4f, " % (accu) sorted_keys = sorted(dset2accu.keys()) for key in sorted_keys: accu_str += "%s Accu %0.4f, " % (key, dset2accu[key]) print(accu_str) self.logger.info(accu_str) dist.barrier() # Validation valid_results, valid_uid2ans = self.evaluate_epoch(epoch=epoch) valid_results = reduce_dict(valid_results, self.args.gpu) if self.args.gpu == 0: valid_total_loss = valid_results['lm_loss'] + valid_results[ 'vis_loss'] + valid_results[ 'matched_loss'] + valid_results['qa_loss'] valid_total_count = valid_results[ 'lm_loss_count'] + valid_results[ 
'vis_loss_count'] + valid_results['matched_loss_count'] # + valid_results['qa_loss_count'] avg_valid_loss = valid_total_loss / valid_total_count losses_str = f"Valid Loss: {avg_valid_loss:.4f}\n" for name, loss in valid_results.items(): if name[-4:] == 'loss': loss_count = int(valid_results[name + '_count']) if loss_count > 0: avg_loss = loss / loss_count if name == 'lm_loss': name = 'Mask_LM' elif name == 'matched_loss': name = 'Matched' elif name == 'obj_loss': name = 'Mask_Obj' elif name == 'attr_loss': name = 'Mask_Attr' elif name == 'feat_loss': name = 'Mask_Feat' elif name == 'qa_loss': name = 'QA' losses_str += f"{name} ({loss_count}): {avg_loss:.4f} " self.writer.add_scalar(f'Valid Loss/{name}', avg_loss, epoch) losses_str += '\n' print(losses_str) self.logger.info(losses_str) if self.args.task_qa: dset2score, dset2cnt, score, cnt = self.val_loader.dataset.evaluator.evaluate( valid_uid2ans) dset2score = reduce_dict(dset2score, self.args.gpu) dset2cnt = reduce_dict(dset2cnt, self.args.gpu) score_cnt_dict = reduce_dict({ 'score': score, 'cnt': cnt }, self.args.gpu) if self.args.gpu == 0: score = score_cnt_dict['score'] cnt = score_cnt_dict['cnt'] accu = score / cnt dset2accu = {} for dset in dset2cnt: dset2accu[dset] = dset2score[dset] / dset2cnt[dset] accu_str = "Overall Accu %0.4f, " % (accu) sorted_keys = sorted(dset2accu.keys()) for key in sorted_keys: accu_str += "%s Accu %0.4f, " % (key, dset2accu[key]) print(accu_str) self.logger.info(accu_str) dist.barrier() if self.verbose: # Save if avg_valid_loss < best_eval_loss: best_eval_loss = avg_valid_loss # self.save("BEST_EVAL_LOSS") self.save("Epoch%02d" % (epoch + 1)) dist.barrier()
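# The meters used throughout these loops expose `update()`, `.val` and `len()` (the audio Finetuner
# further below uses a different meter with `add()` / `reset()` / `summarize_epoch()`). Their
# implementations live elsewhere; the class below is only a hedged sketch of a simple windowed meter
# with the first interface, included for readers of this file.
import collections


class WindowLossMeter:
    def __init__(self, maxlen=100):
        self.vals = collections.deque(maxlen=maxlen)

    def __len__(self):
        return len(self.vals)

    def update(self, new_val):
        self.vals.append(new_val)

    @property
    def val(self):
        # mean over the most recent window, 0.0 before any update
        return sum(self.vals) / len(self.vals) if self.vals else 0.0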
def train(self): LOSSES_NAME = self.args.LOSSES_NAME if self.args.dry: results = self.evaluate_epoch(epoch=0) if self.verbose: loss_meters = [LossMeter() for _ in range(len(LOSSES_NAME))] best_eval_loss = 9595. if 't5' in self.args.backbone: project_name = "VLT5_Pretrain" elif 'bart' in self.args.backbone: project_name = "VLBart_Pretrain" wandb.init(project=project_name) wandb.run.name = self.args.run_name wandb.config.update(self.args) wandb.watch(self.model) src_dir = Path(__file__).resolve().parent base_path = str(src_dir.parent) src_dir = str(src_dir) wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path) if self.args.distributed: dist.barrier() global_step = 0 for epoch in range(self.args.epochs): if self.start_epoch is not None: epoch += self.start_epoch if self.args.distributed: self.train_loader.sampler.set_epoch(epoch) # Train self.model.train() if self.verbose: pbar = tqdm(total=len(self.train_loader), ncols=250) epoch_results = {} for loss_name in LOSSES_NAME: epoch_results[loss_name] = 0. epoch_results[f'{loss_name}_count'] = 0 for step_i, batch in enumerate(self.train_loader): if self.args.fp16 and _use_native_amp: with autocast(): if self.args.distributed: results = self.model.module.train_step(batch) else: results = self.model.train_step(batch) else: if self.args.distributed: results = self.model.module.train_step(batch) else: results = self.model.train_step(batch) loss = results['loss'] if self.args.fp16 and _use_native_amp: self.scaler.scale(loss).backward() elif self.args.fp16 and _use_apex: with amp.scale_loss(loss, self.optim) as scaled_loss: scaled_loss.backward() else: loss.backward() loss = loss.detach() # Update Parameters if self.args.clip_grad_norm > 0: if self.args.fp16 and _use_native_amp: self.scaler.unscale_(self.optim) torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_grad_norm) elif self.args.fp16 and _use_apex: torch.nn.utils.clip_grad_norm_(amp.master_params(self.optim), self.args.clip_grad_norm) else: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_grad_norm) if self.args.fp16 and _use_native_amp: self.scaler.step(self.optim) self.scaler.update() else: self.optim.step() if self.lr_scheduler: self.lr_scheduler.step() # self.model.zero_grad() for param in self.model.parameters(): param.grad = None global_step += 1 if self.lr_scheduler: if version.parse(torch.__version__) >= version.parse("1.4"): lr = self.lr_scheduler.get_last_lr()[0] else: lr = self.lr_scheduler.get_lr()[0] else: try: lr = self.optim.get_lr()[0] except AttributeError: lr = self.args.lr for k, v in results.items(): if k in epoch_results: if isinstance(v, int): epoch_results[k] += v elif isinstance(v, torch.Tensor): epoch_results[k] += v.item() if self.verbose: desc_str = f'Epoch {epoch} | LR {lr:.6f} |' for i, (loss_name, loss_meter) in enumerate(zip(LOSSES_NAME, loss_meters)): if loss_name in results: loss_meter.update(results[f'{loss_name}'] / results[f'{loss_name}_count']) if len(loss_meter) > 0: loss_count = epoch_results[f'{loss_name}_count'] desc_str += f' {loss_name} ({loss_count}) {loss_meter.val:.3f}' pbar.set_description(desc_str) pbar.update(1) if self.verbose: pbar.close() dist.barrier() results = reduce_dict(epoch_results, average=False) if self.verbose: train_loss = results['total_loss'] train_loss_count = results['total_loss_count'] avg_train_loss = train_loss / train_loss_count losses_str = f"Train Loss: {avg_train_loss:.3f}\n" for name, loss in results.items(): if name[-4:] == 'loss': loss_count = 
int(results[name+'_count']) if loss_count > 0: avg_loss = loss/loss_count losses_str += f"{name} ({loss_count}): {avg_loss:.3f} " wandb.log({f'Train Loss/{name}': avg_loss}, step=epoch) losses_str += '\n' print(losses_str) dist.barrier() # Validation valid_results, valid_uid2ans = self.evaluate_epoch(epoch=epoch) valid_results = reduce_dict(valid_results, average=False) if self.verbose: valid_loss = valid_results['total_loss'] valid_loss_count = valid_results['total_loss_count'] avg_valid_loss = valid_loss / valid_loss_count losses_str = f"Valid Loss: {avg_valid_loss:.3f}\n" for name, loss in valid_results.items(): if name[-4:] == 'loss': loss_count = int(valid_results[name+'_count']) if loss_count > 0: avg_loss = loss / loss_count losses_str += f"{name} ({loss_count}): {avg_loss:.3f} " wandb.log({f'Valid Loss/{name}': avg_loss}, step=epoch) losses_str += '\n' print(losses_str) if 'qa' in self.args.losses: dset2score, dset2cnt, score, cnt = self.val_loader.dataset.evaluator.evaluate(valid_uid2ans) if len(dset2score) == 0: dset2score = {'vqa': 0, 'gqa': 0, 'visual7w': 0} dset2cnt = {'vqa': 1, 'gqa': 1, 'visual7w': 1} cnt = 3 score = 0 dset2score = reduce_dict(dset2score, average=False) dset2cnt = reduce_dict(dset2cnt, average=False) score_cnt_dict = reduce_dict({'score': score, 'cnt': cnt}, average=False) if self.args.gpu == 0: score = score_cnt_dict['score'] cnt = score_cnt_dict['cnt'] accu = score / cnt dset2accu = {} for dset in dset2cnt: dset2accu[dset] = dset2score[dset] / dset2cnt[dset] accu_str = "Overall QA Acc %0.4f" % (accu) wandb.log({f'Valid QA Acc/Overall': accu}, step=epoch) sorted_keys = sorted(dset2accu.keys()) for key in sorted_keys: accu_str += ", %s Acc %0.4f" % (key, dset2accu[key]) wandb.log({f'Valid QA Acc/{key}': dset2accu[key]}, step=epoch) print(accu_str) accu_str += '\n\n' dist.barrier() if self.verbose: # Save if avg_valid_loss < best_eval_loss: best_eval_loss = avg_valid_loss # self.save("BEST_EVAL_LOSS") self.save("Epoch%02d" % (epoch + 1)) dist.barrier() if self.verbose: wandb.log({'finished': True})
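# Every trainer in this file reads the current learning rate back through
# `lr_scheduler.get_last_lr()[0]` (PyTorch >= 1.4) and falls back to `get_lr()` on older versions.
# The scheduler itself is constructed outside this section; a common choice for transformer
# trainers like these is linear warmup followed by linear decay, sketched here purely as an
# assumption about what `self.lr_scheduler` might look like.
import torch


def build_linear_warmup_scheduler(optimizer, warmup_steps, total_steps):
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)            # linear warmup
        return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))  # linear decay
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)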
def evaluate_epoch(self, epoch): LOSSES_NAME = self.args.LOSSES_NAME task_dict = { 'Mask_LM': 'word_mask', 'Matched': 'matched', 'Mask_Obj': 'vis_mask', 'Mask_Attr': 'vis_mask', 'Mask_Feat': 'vis_mask', 'QA': 'qa' } epoch_results = { 'lm_loss': 0, 'vis_loss': 0, 'matched_loss': 0, 'qa_loss': 0, 'obj_loss': 0, 'feat_loss': 0, 'attr_loss': 0, } for k in list(epoch_results.keys()): if k[-4:] == 'loss': epoch_results[f'{k}_count'] = 0 uid2ans = {} self.model.eval() with torch.no_grad(): if self.verbose: loss_meter = LossMeter() loss_meters = [LossMeter() for _ in range(len(LOSSES_NAME))] loss_counts = [0 for _ in range(len(LOSSES_NAME))] pbar = tqdm(total=len(self.val_loader), ncols=180) for step_i, batch in enumerate(self.val_loader): # task = random.choice(self.args.MASK_MODALITY) task_i = step_i % len(self.args.MASK_MODALITY) task = self.args.MASK_MODALITY[task_i] if self.args.vis_mask_COCO_only or self.args.vis_mask_COCOVG_only: if task == 'vis_mask': batch['word_id'] = batch['COCO_word_id'] if self.args.clustering: batch['cluster_id'] = batch['COCO_cluster_id'] results = self.forward(batch, task) if task == 'vis_mask': epoch_results['vis_loss_count'] += 1 if 'Mask_Obj' in LOSSES_NAME: epoch_results['obj_loss_count'] += 1 if 'Mask_Feat' in LOSSES_NAME: epoch_results['feat_loss_count'] += 1 if 'Mask_Attr' in LOSSES_NAME: epoch_results['attr_loss_count'] += 1 elif task == 'word_mask': epoch_results['lm_loss_count'] += 1 elif task == 'matched': epoch_results['matched_loss_count'] += 1 elif task == 'qa': epoch_results['qa_loss_count'] += 1 if self.args.task_qa: qa_pred = results['qa_pred'] for uid, ans_id in zip(batch['uid'], qa_pred.cpu().numpy()): ans = self.train_loader.dataset.answer_table.id2ans( ans_id) uid2ans[uid] = ans for k, v in results.items(): if k in epoch_results: epoch_results[k] += v.item() if self.verbose: desc_str = f'Valid Epoch {epoch} | ' # if self.args.task_qa: # loss_meter.update(results['qa_loss'].item()) if self.args.task_qa: loss_meter = loss_meters[-1] loss_meter.update(results['qa_loss'].item()) loss_counts[-1] += 1 for i, (loss_name, loss_meter) in enumerate( zip(LOSSES_NAME, loss_meters)): if task_dict[loss_name] == task: if task == 'vis_mask': if loss_name == 'Mask_Obj': loss_meter.update( results['obj_loss'].item()) elif loss_name == 'Mask_Attr': loss_meter.update( results['attr_loss'].item()) elif loss_name == 'Mask_Feat': loss_meter.update( results['feat_loss'].item()) elif task == 'word_mask': loss_meter.update(results['lm_loss'].item()) elif task == 'matched': loss_meter.update( results['matched_loss'].item()) # elif task == 'qa': # loss_meter.update(results['qa_loss'].item()) loss_counts[i] += 1 if len(loss_meter) > 0: loss_count = loss_counts[i] if loss_name in [ 'Mask_LM', 'Matched', 'Mask_Obj', 'Mask_Attr', 'Mask_Feat', 'QA' ]: desc_str += f' {loss_name} ({loss_count}) {loss_meter.val:.2f}' else: desc_str += f' {loss_name} {loss_meter.val:.2f}' pbar.set_description(desc_str) pbar.update(1) dist.barrier() if self.verbose: pbar.close() dist.barrier() if not self.args.task_qa: uid2ans = None return epoch_results, uid2ans
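# Both pretraining loops unpack `evaluator.evaluate(uid2ans)` into (dset2score, dset2cnt, score,
# cnt). The evaluator belongs to the dataset code; the sketch below shows one plausible shape of
# that computation, assuming a `uid2dset` map from uid to source dataset and a `uid2label` map from
# uid to {answer: weight} dicts. Both maps are assumptions used only for illustration.
from collections import defaultdict


def evaluate_uid2ans_sketch(uid2ans, uid2label, uid2dset):
    dset2score, dset2cnt = defaultdict(float), defaultdict(int)
    for uid, ans in uid2ans.items():
        dset = uid2dset[uid]
        dset2score[dset] += uid2label[uid].get(ans, 0.0)  # soft VQA-style credit
        dset2cnt[dset] += 1
    score, cnt = sum(dset2score.values()), sum(dset2cnt.values())
    return dict(dset2score), dict(dset2cnt), score, cnt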
def train(cont=False): # for tensorboard tracking logger = get_logger() logger.info("(1) Initiating Training ... ") logger.info("Training on device: {}".format(device)) writer = SummaryWriter() # init model aux_layers = None if net == "SETR-PUP": aux_layers, model = get_SETR_PUP() elif net == "SETR-MLA": aux_layers, model = get_SETR_MLA() elif net == "TransUNet-Base": model = get_TransUNet_base() elif net == "TransUNet-Large": model = get_TransUNet_large() elif net == "UNet": model = UNet(CLASS_NUM) # prepare dataset cluster_model = get_clustering_model(logger) train_dataset = CityscapeDataset(img_dir=data_dir, img_dim=IMG_DIM, mode="train", cluster_model=cluster_model) valid_dataset = CityscapeDataset(img_dir=data_dir, img_dim=IMG_DIM, mode="val", cluster_model=cluster_model) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False) logger.info("(2) Dataset Initiated. ") # optimizer epochs = epoch_num if epoch_num > 0 else iteration_num // len( train_loader) + 1 optim = SGD(model.parameters(), lr=lrate, momentum=momentum, weight_decay=wdecay) # optim = Adam(model.parameters(), lr=lrate) scheduler = lr_scheduler.MultiStepLR( optim, milestones=[int(epochs * fine_tune_ratio)], gamma=0.1) cur_epoch = 0 best_loss = float('inf') epochs_since_improvement = 0 # for continue training if cont: model, optim, cur_epoch, best_loss = load_ckpt_continue_training( best_ckpt_src, model, optim, logger) logger.info("Current best loss: {0}".format(best_loss)) with warnings.catch_warnings(): warnings.simplefilter("ignore") for i in range(cur_epoch): scheduler.step() else: model = nn.DataParallel(model) model = model.to(device) logger.info("(3) Model Initiated ... ") logger.info("Training model: {}".format(net) + ". Training Started.") # loss ce_loss = CrossEntropyLoss() if use_dice_loss: dice_loss = DiceLoss(CLASS_NUM) # loop over epochs iter_count = 0 epoch_bar = tqdm.tqdm(total=epochs, desc="Epoch", position=cur_epoch, leave=True) logger.info("Total epochs: {0}. Starting from epoch {1}.".format( epochs, cur_epoch + 1)) for e in range(epochs - cur_epoch): epoch = e + cur_epoch # Training. 
model.train() trainLossMeter = LossMeter() train_batch_bar = tqdm.tqdm(total=len(train_loader), desc="TrainBatch", position=0, leave=True) for batch_num, (orig_img, mask_img) in enumerate(train_loader): orig_img, mask_img = orig_img.float().to( device), mask_img.float().to(device) if net == "TransUNet-Base" or net == "TransUNet-Large": pred = model(orig_img) elif net == "SETR-PUP" or net == "SETR-MLA": if aux_layers is not None: pred, _ = model(orig_img) else: pred = model(orig_img) elif net == "UNet": pred = model(orig_img) loss_ce = ce_loss(pred, mask_img[:].long()) if use_dice_loss: loss_dice = dice_loss(pred, mask_img, softmax=True) loss = 0.5 * (loss_ce + loss_dice) else: loss = loss_ce # Backward Propagation, Update weight and metrics optim.zero_grad() loss.backward() optim.step() # update learning rate for param_group in optim.param_groups: orig_lr = param_group['lr'] param_group['lr'] = orig_lr * (1.0 - iter_count / iteration_num)**0.9 iter_count += 1 # Update loss trainLossMeter.update(loss.item()) # print status if (batch_num + 1) % print_freq == 0: status = 'Epoch: [{0}][{1}/{2}]\t' \ 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch+1, batch_num+1, len(train_loader), loss=trainLossMeter) logger.info(status) # log loss to tensorboard if (batch_num + 1) % tensorboard_freq == 0: writer.add_scalar( 'Train_Loss_{0}'.format(tensorboard_freq), trainLossMeter.avg, epoch * (len(train_loader) / tensorboard_freq) + (batch_num + 1) / tensorboard_freq) train_batch_bar.update(1) writer.add_scalar('Train_Loss_epoch', trainLossMeter.avg, epoch) # Validation. model.eval() validLossMeter = LossMeter() valid_batch_bar = tqdm.tqdm(total=len(valid_loader), desc="ValidBatch", position=0, leave=True) with torch.no_grad(): for batch_num, (orig_img, mask_img) in enumerate(valid_loader): orig_img, mask_img = orig_img.float().to( device), mask_img.float().to(device) if net == "TransUNet-Base" or net == "TransUNet-Large": pred = model(orig_img) elif net == "SETR-PUP" or net == "SETR-MLA": if aux_layers is not None: pred, _ = model(orig_img) else: pred = model(orig_img) elif net == "UNet": pred = model(orig_img) loss_ce = ce_loss(pred, mask_img[:].long()) if use_dice_loss: loss_dice = dice_loss(pred, mask_img, softmax=True) loss = 0.5 * (loss_ce + loss_dice) else: loss = loss_ce # Update loss validLossMeter.update(loss.item()) # print status if (batch_num + 1) % print_freq == 0: status = 'Validation: [{0}][{1}/{2}]\t' \ 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch+1, batch_num+1, len(valid_loader), loss=validLossMeter) logger.info(status) # log loss to tensorboard if (batch_num + 1) % tensorboard_freq == 0: writer.add_scalar( 'Valid_Loss_{0}'.format(tensorboard_freq), validLossMeter.avg, epoch * (len(valid_loader) / tensorboard_freq) + (batch_num + 1) / tensorboard_freq) valid_batch_bar.update(1) valid_loss = validLossMeter.avg writer.add_scalar('Valid_Loss_epoch', valid_loss, epoch) logger.info("Validation Loss of epoch [{0}/{1}]: {2}\n".format( epoch + 1, epochs, valid_loss)) # update optim scheduler scheduler.step() # save checkpoint is_best = valid_loss < best_loss best_loss_tmp = min(valid_loss, best_loss) if not is_best: epochs_since_improvement += 1 logger.info("Epochs since last improvement: %d\n" % (epochs_since_improvement, )) if epochs_since_improvement == early_stop_tolerance: break # early stopping. 
else: epochs_since_improvement = 0 state = { 'epoch': epoch, 'loss': best_loss_tmp, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optim.state_dict(), } torch.save(state, ckpt_src) logger.info("Checkpoint updated.") best_loss = best_loss_tmp epoch_bar.update(1) writer.close()
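# The segmentation loop above rescales each param_group's *current* LR by
# (1 - iter_count / iteration_num) ** 0.9 on every iteration, so the factor compounds across steps
# (on top of the MultiStepLR milestone decay). The conventional "poly" schedule applies that factor
# to a fixed base LR instead; a sketch of that variant is given here for comparison, with `base_lr`
# as a placeholder.
def poly_lr(base_lr, cur_iter, max_iter, power=0.9):
    return base_lr * (1.0 - cur_iter / max_iter) ** power

# usage sketch:
# for g in optim.param_groups:
#     g['lr'] = poly_lr(base_lr, iter_count, iteration_num)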
def train(self): if self.verbose: loss_meter = LossMeter() quesid2ans = {} best_valid = 0. print("Valid Oracle: %0.2f" % (self.oracle_score(self.val_loader) * 100)) from torch.utils.tensorboard import SummaryWriter self.writer = SummaryWriter(log_dir=self.args.log_dir) hparam_dict = {} for k, v in self.args.__dict__.items(): if type(v) in [int, float, str, bool, torch.Tensor]: hparam_dict[k] = v metric_dict = {} self.writer.add_hparams(hparam_dict, metric_dict) if self.args.distributed: dist.barrier() self.optim.zero_grad() for epoch in range(self.args.epochs): if self.start_epoch is not None: epoch += self.start_epoch self.model.train() if self.args.distributed: self.train_loader.sampler.set_epoch(epoch) if self.verbose: pbar = tqdm(total=len(self.train_loader), ncols=150) quesid2ans = {} for step_i, batch in enumerate(self.train_loader): update = True if self.args.update_freq > 1: if step_i == 0: update = False elif step_i % self.args.update_freq == 0 or step_i == len( self.train_loader) - 1: update = True else: update = False vis_feats = batch['vis_feats'].cuda() boxes = batch['boxes'].cuda() input_ids = batch['word_ids'].cuda() target = batch['targets'].cuda() ques_id = batch['question_ids'] B = len(batch['word_ids']) results = self.model( input_ids=input_ids, visual_feats=vis_feats, visual_pos=boxes, attention_mask=input_ids > 0, ) logit = results['logit'] assert logit.size() == target.size() assert logit.size() == (B, self.num_answers) loss = self.bce_loss(logit, target) loss.backward() if update: if not self.args.no_clip_grad: nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_grad_norm) self.optim.step() self.lr_scheduler.step() for param in self.model.parameters(): param.grad = None try: lr = self.scheduler.get_last_lr()[0] except AttributeError: lr = self.args.lr if self.verbose: loss_meter.update(loss.item()) desc_str = f'Epoch {epoch} | LR {lr:.6f} | ' desc_str += f'Loss {loss_meter.val:4f} |' score, predict = logit.max(1) predict = predict.cpu().numpy() target = target.cpu().numpy() for qid, pred in zip(ques_id, predict): pred_ans = self.train_loader.dataset.raw_dataset.label2ans[ pred] quesid2ans[qid] = pred_ans pbar.set_description(desc_str) pbar.update(1) if self.args.distributed: dist.barrier() # score, label = logit.max(1) # for qid, l in zip(ques_id, label.cpu().numpy()): # ans = dset.label2ans[l] # quesid2ans[qid.item()] = ans if self.verbose: pbar.close() score = self.train_loader.evaluator.evaluate(quesid2ans) * 100. log_str = "\nEpoch %d: Train %0.2f" % (epoch, score) if not self.args.dry: self.writer.add_scalar(f'GQA/Train/score', score, epoch) # Validation valid_score = self.evaluate(self.val_loader) * 100. if valid_score > best_valid: best_valid = valid_score self.save("BEST") log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_score) log_str += "\nEpoch %d: Best %0.2f\n" % (epoch, best_valid) if not self.args.dry: self.writer.add_scalar(f'GQA/Valid/score', valid_score, epoch) print(log_str) self.logger.info(log_str) if self.args.distributed: dist.barrier() if self.verbose: self.save("LAST")
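# The GQA loop above scores a (B, num_answers) logit tensor against soft target vectors with
# `self.bce_loss`; its construction is outside this section, but it is assumed here to be something
# like nn.BCEWithLogitsLoss. A minimal, self-contained version of that loss on dummy tensors is
# shown below; the batch size and answer count are illustrative only.
import torch
import torch.nn as nn

num_answers = 1842                               # illustrative answer-vocabulary size
bce_loss = nn.BCEWithLogitsLoss()
logit = torch.randn(4, num_answers)              # (batch, num_answers) raw scores
target = torch.zeros(4, num_answers)
target[torch.arange(4), torch.randint(0, num_answers, (4,))] = 1.0  # one-hot / soft labels
loss = bce_loss(logit, target)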
def train(self): if self.verbose: loss_meter = LossMeter() best_valid = 0. best_epoch = 0 if 't5' in self.args.backbone: if self.args.use_vision: project_name = "VLT5_MMT" else: project_name = "T5_MMT" elif 'bart' in self.args.backbone: if self.args.use_vision: project_name = "VLBart_MMT" else: project_name = "Bart_MMT" wandb.init(project=project_name) wandb.run.name = self.args.run_name wandb.config.update(self.args) wandb.watch(self.model) src_dir = Path(__file__).resolve().parent base_path = str(src_dir.parent) src_dir = str(src_dir) wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path) if self.args.distributed: dist.barrier() global_step = 0 for epoch in range(self.args.epochs): if self.start_epoch is not None: epoch += self.start_epoch self.model.train() if self.args.distributed: self.train_loader.sampler.set_epoch(epoch) if self.verbose: pbar = tqdm(total=len(self.train_loader), ncols=120) epoch_results = { 'loss': 0., } for step_i, batch in enumerate(self.train_loader): # self.optim.zero_grad() if self.args.fp16 and _use_native_amp: with autocast(): if self.args.distributed: results = self.model.module.train_step(batch) else: results = self.model.train_step(batch) else: if self.args.distributed: results = self.model.module.train_step(batch) else: results = self.model.train_step(batch) loss = results['loss'] # print(f'GPU{self.args.gpu} after loss') if self.args.fp16 and _use_native_amp: self.scaler.scale(loss).backward() elif self.args.fp16 and _use_apex: with amp.scale_loss(loss, self.optim) as scaled_loss: scaled_loss.backward() else: loss.backward() # print(f'GPU{self.args.gpu} after backward') loss = loss.detach() # Update Parameters if self.args.clip_grad_norm > 0: if self.args.fp16 and _use_native_amp: self.scaler.unscale_(self.optim) torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.clip_grad_norm) elif self.args.fp16 and _use_apex: torch.nn.utils.clip_grad_norm_(amp.master_params( self.optim), self.args.clip_grad_norm) else: torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.clip_grad_norm) update = True if self.args.gradient_accumulation_steps > 1: if step_i == 0: update = False elif step_i % self.args.gradient_accumulation_steps == 0 or step_i == len(self.train_loader) - 1: update = True else: update = False if update: if self.args.fp16 and _use_native_amp: self.scaler.step(self.optim) self.scaler.update() else: self.optim.step() if self.lr_scheduler: self.lr_scheduler.step() # self.model.zero_grad() for param in self.model.parameters(): param.grad = None global_step += 1 for k, v in results.items(): if k in epoch_results: epoch_results[k] += v.item() if self.lr_scheduler: if version.parse(torch.__version__) >= version.parse("1.4"): lr = self.lr_scheduler.get_last_lr()[0] else: lr = self.lr_scheduler.get_lr()[0] else: try: lr = self.optim.get_lr()[0] except AttributeError: lr = self.args.lr if self.verbose: loss_meter.update(loss.item()) desc_str = f'Epoch {epoch} | LR {lr:.6f} | Steps {global_step}' desc_str += f' | Loss {loss_meter.val:4f}' pbar.set_description(desc_str) pbar.update(1) # if self.args.distributed: # dist.barrier() if self.verbose: pbar.close() # Validation valid_results = self.evaluate(self.val_loader) valid_score = valid_results['BLEU'] if valid_score > best_valid: best_valid = valid_score best_epoch = epoch self.save("BEST") log_str = '' log_str += pformat(valid_results) log_str += "\nEpoch %d: Valid BLEU %0.4f" % (epoch, valid_score) log_str += "\nEpoch %d: Best BLEU %0.4f\n" % (best_epoch, best_valid) 
wandb_log_dict = {} wandb_log_dict['Train/Loss'] = epoch_results['loss'] / len(self.train_loader) for score_name, score in valid_results.items(): wandb_log_dict[f'Valid/{score_name}'] = score wandb.log(wandb_log_dict, step=epoch) print(log_str) if self.args.distributed: dist.barrier() if self.verbose: self.save("LAST") # Test Set best_path = os.path.join(self.args.output, 'BEST') self.load(best_path) if isinstance(self.test_loader, list): for loader in self.test_loader: split = loader.dataset.source dump_path = os.path.join(self.args.output, f'submit_{split}_raw.txt') test_results = self.evaluate(loader, dump_path=dump_path) wandb_log_dict = {} for score_name, score in test_results.items(): wandb_log_dict[f'{split}/{score_name}'] = score wandb.log(wandb_log_dict, step=epoch) log_str = f'{split} set results\n' log_str += pformat(test_results) print(log_str) wandb.save(dump_path, base_path=self.args.output) print('\nUploaded', dump_path) else: split = self.test_loader.dataset.source dump_path = os.path.join(self.args.output, f'submit_{split}_raw.txt') test_results = self.evaluate(self.test_loader, dump_path=dump_path) wandb_log_dict = {} for score_name, score in test_results.items(): wandb_log_dict[f'{split}/{score_name}'] = score wandb.log(wandb_log_dict, step=epoch) log_str = f'{split} set results\n' log_str += pformat(test_results) print(log_str) wandb.save(dump_path, base_path=self.args.output) print('\nUploaded', dump_path) wandb.log({'finished': True})
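# For MMT, `self.evaluate()` returns a dict whose 'BLEU' entry drives checkpoint selection above.
# The metric code lives elsewhere in the repo; the snippet below is only a hedged sketch of
# computing corpus BLEU with sacrebleu, assuming plain detokenized hypothesis and reference strings.
import sacrebleu


def corpus_bleu_sketch(hypotheses, references):
    """hypotheses: list[str]; references: list[str], one reference per sentence."""
    return sacrebleu.corpus_bleu(hypotheses, [references]).score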
class Finetuner: def __init__(self, args): self.args = args self.args.n_datasets = len(args.data) self.modelPath = Path('checkpoints') / args.expName self.logger = create_output_dir(args, self.modelPath) self.data = [DatasetSet(d, args.seq_len, args) for d in args.data] self.losses_recon = [ LossMeter(f'recon {i}') for i in range(self.args.n_datasets) ] self.loss_total = LossMeter('total') self.evals_recon = [ LossMeter(f'recon {i}') for i in range(self.args.n_datasets) ] self.eval_total = LossMeter('eval total') self.start_epoch = 0 #torch.manual_seed(args.seed) #torch.cuda.manual_seed(args.seed) #get the pretrained model checkpoints checkpoint = args.checkpoint.parent.glob(args.checkpoint.name + '_*.pth') checkpoint = [c for c in checkpoint if extract_id(c) in args.decoder][0] model_args = torch.load(args.checkpoint.parent / 'args.pth')[0] self.encoder = Encoder(model_args) self.decoder = WaveNet(model_args) self.encoder = Encoder(model_args) self.encoder.load_state_dict(torch.load(checkpoint)['encoder_state']) #encoder freeze for param in self.encoder.parameters(): param.requires_grad = False #self.logger.debug(f'encoder at start: {param}') self.decoder = WaveNet(model_args) self.decoder.load_state_dict(torch.load(checkpoint)['decoder_state']) #decoder freeze for param in self.decoder.layers[:-args.decoder_update].parameters(): param.requires_grad = False #self.logger.debug(f'decoder at start: {param}') self.encoder = torch.nn.DataParallel(self.encoder).cuda() self.decoder = torch.nn.DataParallel(self.decoder).cuda() self.model_optimizer = optim.Adam(chain(self.encoder.parameters(), self.decoder.parameters()), lr=args.lr) self.lr_manager = torch.optim.lr_scheduler.ExponentialLR( self.model_optimizer, args.lr_decay) self.lr_manager.step() def train_batch(self, x, x_aug, dset_num): 'train batch without considering the discriminator' x = x.float() x_aug = x_aug.float() z = self.encoder(x_aug) y = self.decoder(x, z) recon_loss = cross_entropy_loss(y, x) self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean()) loss = recon_loss.mean() self.model_optimizer.zero_grad() loss.backward() self.model_optimizer.step() self.loss_total.add(loss.data.item()) return loss.data.item() def train_epoch(self, epoch): for meter in self.losses_recon: meter.reset() self.loss_total.reset() self.decoder.train() n_batches = self.args.epoch_len with tqdm(total=n_batches, desc='Train epoch %d' % epoch) as train_enum: for batch_num in range(n_batches): if self.args.short and batch_num == 3: break if self.args.distributed: assert self.args.rank < self.args.n_datasets, "No. 
of workers must be equal to #dataset" # dset_num = (batch_num + self.args.rank) % self.args.n_datasets dset_num = self.args.rank else: dset_num = batch_num % self.args.n_datasets x, x_aug = next(self.data[dset_num].train_iter) x = wrap(x) x_aug = wrap(x_aug) batch_loss = self.train_batch(x, x_aug, dset_num) train_enum.set_description( f'Train (loss: {batch_loss:.2f}) epoch {epoch}') train_enum.update() def eval_batch(self, x, x_aug, dset_num): x, x_aug = x.float(), x_aug.float() z = self.encoder(x) y = self.decoder(x, z) recon_loss = cross_entropy_loss(y, x) self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean()) total_loss = recon_loss.mean().data.item() self.eval_total.add(total_loss) return total_loss def evaluate_epoch(self, epoch): for meter in self.evals_recon: meter.reset() self.eval_total.reset() self.encoder.eval() self.decoder.eval() n_batches = int(np.ceil(self.args.epoch_len / 10)) with tqdm(total=n_batches) as valid_enum, torch.no_grad(): for batch_num in range(n_batches): if self.args.short and batch_num == 10: break if self.args.distributed: assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset" dset_num = self.args.rank else: dset_num = batch_num % self.args.n_datasets x, x_aug = next(self.data[dset_num].valid_iter) x = wrap(x) x_aug = wrap(x_aug) batch_loss = self.eval_batch(x, x_aug, dset_num) valid_enum.set_description( f'Test (loss: {batch_loss:.2f}) epoch {epoch}') valid_enum.update() @staticmethod def format_losses(meters): losses = [meter.summarize_epoch() for meter in meters] return ', '.join('{:.4f}'.format(x) for x in losses) def train_losses(self): meters = [*self.losses_recon] return self.format_losses(meters) def eval_losses(self): meters = [*self.evals_recon] return self.format_losses(meters) def finetune(self): best_eval = float('inf') for epoch in range(self.start_epoch, self.start_epoch + self.args.epochs): self.logger.info( f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}' ) self.train_epoch(epoch) self.evaluate_epoch(epoch) self.logger.info( f'Epoch %s Rank {self.args.rank} - Train loss: (%s), Test loss (%s)', epoch, self.train_losses(), self.eval_losses()) self.lr_manager.step() val_loss = self.eval_total.summarize_epoch() if val_loss < best_eval: self.save_model(f'bestmodel_{self.args.rank}.pth') best_eval = val_loss if not self.args.per_epoch: self.save_model(f'lastmodel_{self.args.rank}.pth') else: self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth') if self.args.is_master: torch.save([self.args, epoch], '%s/args.pth' % self.modelPath) self.logger.debug('Ended epoch') def save_model(self, filename): save_path = self.modelPath / filename torch.save( { 'encoder_state': self.encoder.module.state_dict(), 'decoder_state': self.decoder.module.state_dict(), 'model_optimizer_state': self.model_optimizer.state_dict(), 'dataset': self.args.rank, }, save_path) self.logger.debug(f'Saved model to {save_path}')
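# The Finetuner above freezes the encoder and all but the last `decoder_update` decoder layers and
# then hands every parameter to Adam; parameters with requires_grad=False never receive gradients,
# so the optimizer simply skips them. An equivalent, slightly tighter pattern is to pass only the
# trainable parameters to the optimizer, sketched below with `encoder` and `decoder` standing in
# for the modules built in __init__.
import torch.optim as optim


def build_finetune_optimizer(encoder, decoder, n_trainable_layers, lr):
    for p in encoder.parameters():
        p.requires_grad = False                                  # freeze the encoder entirely
    for p in decoder.layers[:-n_trainable_layers].parameters():
        p.requires_grad = False                                  # freeze all but the last N layers
    trainable = [p for p in decoder.parameters() if p.requires_grad]
    return optim.Adam(trainable, lr=lr)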