import os
import sys
import time
import random
import pprint
import logging

import torch
import torch.distributed as dist

# Project-local modules (assumed to be importable from the surrounding repo):
# cfg, models, losses, datasets, utils, Evaler, Scorer, Optimizer, AverageMeter.


class Trainer(object):
    def __init__(self, args):
        super(Trainer, self).__init__()
        self.args = args
        if cfg.SEED > 0:
            random.seed(cfg.SEED)
            torch.manual_seed(cfg.SEED)
            torch.cuda.manual_seed_all(cfg.SEED)

        self.num_gpus = torch.cuda.device_count()
        self.distributed = self.num_gpus > 1
        if self.distributed:
            torch.cuda.set_device(args.local_rank)
            torch.distributed.init_process_group(
                backend="nccl", init_method="env://")
        self.device = torch.device("cuda")

        self.rl_stage = False
        self.setup_logging()
        self.setup_dataset()
        self.setup_network()
        self.val_evaler = Evaler(
            eval_ids=cfg.DATA_LOADER.VAL_ID,
            gv_feat=cfg.DATA_LOADER.VAL_GV_FEAT,
            att_feats=cfg.DATA_LOADER.VAL_ATT_FEATS,
            eval_annfile=cfg.INFERENCE.VAL_ANNFILE)
        self.test_evaler = Evaler(
            eval_ids=cfg.DATA_LOADER.TEST_ID,
            gv_feat=cfg.DATA_LOADER.TEST_GV_FEAT,
            att_feats=cfg.DATA_LOADER.TEST_ATT_FEATS,
            eval_annfile=cfg.INFERENCE.TEST_ANNFILE)
        self.scorer = Scorer()

    def setup_logging(self):
        # Only rank 0 writes logs; output goes to stdout and to a file in cfg.ROOT_DIR.
        self.logger = logging.getLogger(cfg.LOGGER_NAME)
        self.logger.setLevel(logging.INFO)
        if self.distributed and dist.get_rank() > 0:
            return

        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.INFO)
        formatter = logging.Formatter("[%(levelname)s: %(asctime)s] %(message)s")
        ch.setFormatter(formatter)
        self.logger.addHandler(ch)

        if not os.path.exists(cfg.ROOT_DIR):
            os.makedirs(cfg.ROOT_DIR)
        fh = logging.FileHandler(os.path.join(cfg.ROOT_DIR, cfg.LOGGER_NAME + '.txt'))
        fh.setLevel(logging.INFO)
        fh.setFormatter(formatter)
        self.logger.addHandler(fh)

        self.logger.info('Training with config:')
        self.logger.info(pprint.pformat(cfg))

    def setup_network(self):
        model = models.create(cfg.MODEL.TYPE)
        if self.distributed:
            # this should be removed if we update BatchNorm stats
            self.model = torch.nn.parallel.DistributedDataParallel(
                model.to(self.device),
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                broadcast_buffers=False)
        else:
            self.model = torch.nn.DataParallel(model).cuda()

        if self.args.resume > 0:
            self.model.load_state_dict(
                torch.load(self.snapshot_path("caption_model", self.args.resume),
                           map_location=lambda storage, loc: storage))

        self.optim = Optimizer(self.model)
        self.xe_criterion = losses.create(cfg.LOSSES.XE_TYPE).cuda()
        self.rl_criterion = losses.create(cfg.LOSSES.RL_TYPE).cuda()

    def setup_dataset(self):
        self.coco_set = datasets.coco_dataset.CocoDataset(
            image_ids_path=cfg.DATA_LOADER.TRAIN_ID,
            input_seq=cfg.DATA_LOADER.INPUT_SEQ_PATH,
            target_seq=cfg.DATA_LOADER.TARGET_SEQ_PATH,
            gv_feat_path=cfg.DATA_LOADER.TRAIN_GV_FEAT,
            att_feats_folder=cfg.DATA_LOADER.TRAIN_ATT_FEATS,
            seq_per_img=cfg.DATA_LOADER.SEQ_PER_IMG,
            max_feat_num=cfg.DATA_LOADER.MAX_FEAT)

    def setup_loader(self, epoch):
        self.training_loader = datasets.data_loader.load_train(
            self.distributed, epoch, self.coco_set)

    def eval(self, epoch):
        if (epoch + 1) % cfg.SOLVER.TEST_INTERVAL != 0:
            return None
        if self.distributed and dist.get_rank() > 0:
            return None

        val_res = self.val_evaler(self.model, 'val_' + str(epoch + 1))
        self.logger.info('######## Epoch (VAL)' + str(epoch + 1) + ' ########')
        self.logger.info(str(val_res))

        test_res = self.test_evaler(self.model, 'test_' + str(epoch + 1))
        self.logger.info('######## Epoch (TEST)' + str(epoch + 1) + ' ########')
        self.logger.info(str(test_res))

        # Negative weighted sum of the validation metrics; passed to the
        # epoch-level scheduler step in train().
        val = 0
        for score_type, weight in zip(cfg.SCORER.TYPES, cfg.SCORER.WEIGHTS):
            val -= val_res[score_type] * weight
        return val

    def snapshot_path(self, name, epoch):
        snapshot_folder = os.path.join(cfg.ROOT_DIR, 'snapshot')
        return os.path.join(snapshot_folder, name + "_" + str(epoch) + ".pth")

    def save_model(self, epoch):
        if (epoch + 1) % cfg.SOLVER.SNAPSHOT_ITERS != 0:
            return
        if self.distributed and dist.get_rank() > 0:
            return
        snapshot_folder = os.path.join(cfg.ROOT_DIR, 'snapshot')
        if not os.path.exists(snapshot_folder):
            os.mkdir(snapshot_folder)
        torch.save(self.model.state_dict(),
                   self.snapshot_path("caption_model", epoch + 1))

    def make_kwargs(self, indices, input_seq, target_seq, gv_feat, att_feats, att_mask):
        # Trim the padded sequences to the longest sequence in the batch.
        seq_mask = (input_seq > 0).type(torch.cuda.LongTensor)
        seq_mask[:, 0] += 1
        seq_mask_sum = seq_mask.sum(-1)
        max_len = int(seq_mask_sum.max())
        input_seq = input_seq[:, 0:max_len].contiguous()
        target_seq = target_seq[:, 0:max_len].contiguous()

        kwargs = {
            cfg.PARAM.INDICES: indices,
            cfg.PARAM.INPUT_SENT: input_seq,
            cfg.PARAM.TARGET_SENT: target_seq,
            cfg.PARAM.GLOBAL_FEAT: gv_feat,
            cfg.PARAM.ATT_FEATS: att_feats,
            cfg.PARAM.ATT_FEATS_MASK: att_mask
        }
        return kwargs

    def scheduled_sampling(self, epoch):
        # Linearly increase the scheduled-sampling probability, capped at MAX_PROB.
        if epoch > cfg.TRAIN.SCHEDULED_SAMPLING.START:
            frac = (epoch - cfg.TRAIN.SCHEDULED_SAMPLING.START) // cfg.TRAIN.SCHEDULED_SAMPLING.INC_EVERY
            ss_prob = min(cfg.TRAIN.SCHEDULED_SAMPLING.INC_PROB * frac,
                          cfg.TRAIN.SCHEDULED_SAMPLING.MAX_PROB)
            self.model.module.ss_prob = ss_prob

    def display(self, iteration, data_time, batch_time, losses, loss_info):
        if iteration % cfg.SOLVER.DISPLAY != 0:
            return
        if self.distributed and dist.get_rank() > 0:
            return
        info_str = ' (DataTime/BatchTime: {:.3}/{:.3}) losses = {:.5}'.format(
            data_time.avg, batch_time.avg, losses.avg)
        self.logger.info('Iteration ' + str(iteration) + info_str + ', lr = ' + str(self.optim.get_lr()))
        for name in sorted(loss_info):
            self.logger.info(' ' + name + ' = ' + str(loss_info[name]))
        data_time.reset()
        batch_time.reset()
        losses.reset()

    def forward(self, kwargs):
        if not self.rl_stage:
            # Cross-entropy training.
            logit = self.model(**kwargs)
            loss, loss_info = self.xe_criterion(logit, kwargs[cfg.PARAM.TARGET_SENT])
        else:
            # Self-critical training: greedy decoding provides the reward baseline,
            # sampled captions provide the policy-gradient term.
            ids = kwargs[cfg.PARAM.INDICES]
            gv_feat = kwargs[cfg.PARAM.GLOBAL_FEAT]
            att_feats = kwargs[cfg.PARAM.ATT_FEATS]
            att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]

            # max (greedy baseline)
            kwargs['BEAM_SIZE'] = 1
            kwargs['GREEDY_DECODE'] = True
            kwargs[cfg.PARAM.GLOBAL_FEAT] = gv_feat
            kwargs[cfg.PARAM.ATT_FEATS] = att_feats
            kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask

            self.model.eval()
            with torch.no_grad():
                seq_max, logP_max = self.model.module.decode(**kwargs)
            self.model.train()
            rewards_max, rewards_info_max = self.scorer(ids, seq_max.data.cpu().numpy().tolist())
            rewards_max = utils.expand_numpy(rewards_max)

            ids = utils.expand_numpy(ids)
            gv_feat = utils.expand_tensor(gv_feat, cfg.DATA_LOADER.SEQ_PER_IMG)
            att_feats = utils.expand_tensor(att_feats, cfg.DATA_LOADER.SEQ_PER_IMG)
            att_mask = utils.expand_tensor(att_mask, cfg.DATA_LOADER.SEQ_PER_IMG)

            # sample
            kwargs['BEAM_SIZE'] = 1
            kwargs['GREEDY_DECODE'] = False
            kwargs[cfg.PARAM.GLOBAL_FEAT] = gv_feat
            kwargs[cfg.PARAM.ATT_FEATS] = att_feats
            kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask

            seq_sample, logP_sample = self.model.module.decode(**kwargs)
            rewards_sample, rewards_info_sample = self.scorer(ids, seq_sample.data.cpu().numpy().tolist())

            rewards = rewards_sample - rewards_max
            rewards = torch.from_numpy(rewards).float().cuda()
            loss = self.rl_criterion(seq_sample, logP_sample, rewards)

            loss_info = {}
            for key in rewards_info_sample:
                loss_info[key + '_sample'] = rewards_info_sample[key]
            for key in rewards_info_max:
                loss_info[key + '_max'] = rewards_info_max[key]
        return loss, loss_info

    def train(self):
        self.model.train()
        self.optim.zero_grad()

        iteration = 0
        for epoch in range(cfg.SOLVER.MAX_EPOCH):
            if epoch == cfg.TRAIN.REINFORCEMENT.START:
                self.rl_stage = True
            self.setup_loader(epoch)

            start = time.time()
            data_time = AverageMeter()
            batch_time = AverageMeter()
            losses = AverageMeter()
            for _, (indices, input_seq, target_seq, gv_feat, att_feats, att_mask) in enumerate(self.training_loader):
                data_time.update(time.time() - start)

                input_seq = input_seq.cuda()
                target_seq = target_seq.cuda()
                gv_feat = gv_feat.cuda()
                att_feats = att_feats.cuda()
                att_mask = att_mask.cuda()

                kwargs = self.make_kwargs(indices, input_seq, target_seq, gv_feat, att_feats, att_mask)
                loss, loss_info = self.forward(kwargs)
                loss.backward()
                utils.clip_gradient(self.optim.optimizer, self.model,
                                    cfg.SOLVER.GRAD_CLIP_TYPE, cfg.SOLVER.GRAD_CLIP)
                self.optim.step()
                self.optim.zero_grad()
                self.optim.scheduler_step('Iter')

                batch_time.update(time.time() - start)
                start = time.time()
                losses.update(loss.item())
                self.display(iteration, data_time, batch_time, losses, loss_info)
                iteration += 1

                if self.distributed:
                    dist.barrier()

            self.save_model(epoch)
            val = self.eval(epoch)
            self.optim.scheduler_step('Epoch', val)
            self.scheduled_sampling(epoch)

            if self.distributed:
                dist.barrier()
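
# Illustration only (not part of the original file): a minimal sketch of the
# self-critical loss that rl_criterion is expected to compute in forward().
# The reward for each sampled caption is its score minus the greedy baseline's
# score, and the loss is the negative reward-weighted log-probability of the
# sampled tokens. The exact masking and averaging in
# losses.create(cfg.LOSSES.RL_TYPE) may differ; pad_token=0 is an assumption.
def self_critical_loss(seq_sample, logP_sample, rewards, pad_token=0):
    # seq_sample:  (batch, max_len) sampled token ids
    # logP_sample: (batch, max_len) log-probabilities of the sampled tokens
    # rewards:     (batch,) score(sample) - score(greedy baseline)
    mask = (seq_sample != pad_token).float()            # ignore padding after EOS
    loss = -logP_sample * rewards.unsqueeze(-1) * mask  # policy-gradient term
    return loss.sum() / mask.sum().clamp(min=1.0)
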
from datetime import timedelta

# A second Trainer variant for BERT-style vision-language pre-training.
# BertTokenizer, BertConfig, BaseBertPreTraining and load_concap_train are
# assumed to be provided by the surrounding repo and its BERT dependency.


class Trainer(object):
    def __init__(self, args):
        super(Trainer, self).__init__()
        self.args = args
        self.setup_gpu()
        self.setup_logging()
        self.setup_loader()
        self.setup_network()

    def setup_gpu(self):
        if self.args.local_rank == -1:
            # Single-process mode: use every visible GPU (or the CPU).
            self.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu")
            self.n_gpu = torch.cuda.device_count()
            self.distributed = False
        else:
            # Distributed mode: one process per GPU, identified by its local rank.
            torch.cuda.set_device(self.args.local_rank)
            self.device = torch.device("cuda", self.args.local_rank)
            self.n_gpu = 1
            torch.distributed.init_process_group(
                backend="nccl", init_method="env://",
                timeout=timedelta(minutes=180))
            self.distributed = True
        print("device: {} n_gpu: {}, distributed training: {}".format(
            self.device, self.n_gpu, bool(self.args.local_rank != -1)))

        if cfg.SEED > 0:
            random.seed(cfg.SEED)
            torch.manual_seed(cfg.SEED)
            torch.cuda.manual_seed_all(cfg.SEED)

    def setup_logging(self):
        self.logger = logging.getLogger(cfg.LOGGER_NAME)
        self.logger.setLevel(logging.INFO)
        if self.distributed and dist.get_rank() > 0:
            return

        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.INFO)
        formatter = logging.Formatter(
            "[%(levelname)s: %(asctime)s] %(message)s")
        ch.setFormatter(formatter)
        self.logger.addHandler(ch)

        fh = logging.FileHandler(
            os.path.join(cfg.ROOT_DIR, cfg.LOGGER_NAME + '.txt'))
        fh.setLevel(logging.INFO)
        fh.setFormatter(formatter)
        self.logger.addHandler(fh)

        self.logger.info('Training with config:')
        self.logger.info(pprint.pformat(cfg))

    def setup_loader(self):
        self.tokenizer = BertTokenizer.from_pretrained(
            cfg.TRAIN.BERT_MODEL, do_lower_case=cfg.TRAIN.DO_LOWER_CASE)
        self.train_dataset_loader, self.train_dataset, self.train_sampler = \
            load_concap_train(self.args.local_rank, self.tokenizer)

    def setup_network(self):
        config = BertConfig.from_json_file(cfg.CONFIG_FILE)
        if cfg.TRAIN.FROM_PRETRAINED:
            model = BaseBertPreTraining.from_pretrained(
                cfg.TRAIN.FROM_PRETRAINED, config)
        else:
            model = BaseBertPreTraining(config)
        model.to(self.device)

        if self.args.local_rank != -1:
            self.model = torch.nn.parallel.DistributedDataParallel(
                model,
                find_unused_parameters=True,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                broadcast_buffers=False)
        elif self.n_gpu > 1:
            self.model = torch.nn.DataParallel(model)
        else:
            self.model = model

        epoch_steps = len(self.train_dataset_loader)
        n_steps = epoch_steps * cfg.SOLVER.NUM_TRAIN_EPOCHS
        self.optim = Optimizer(self.model, epoch_steps=epoch_steps, n_steps=n_steps)

    def display(self, iteration, batch_time, losses, loss_info):
        if iteration % cfg.SOLVER.DISPLAY != 0:
            return
        if self.distributed and dist.get_rank() > 0:
            return
        info_str = ' (BatchTime: {:.3}) losses = {:.5}'.format(
            batch_time.avg, losses.avg)
        self.logger.info('Iteration ' + str(iteration) + info_str + ', lr = ' + str(self.optim.get_lr()))
        for name in sorted(loss_info):
            self.logger.info(' ' + name + ' = ' + str(loss_info[name]))
        batch_time.reset()
        losses.reset()

    def snapshot_path(self, name, epoch):
        snapshot_folder = os.path.join(cfg.ROOT_DIR, 'snapshot')
        return os.path.join(snapshot_folder, name + "_" + str(epoch) + ".bin")

    def save_model(self, epoch):
        if (epoch + 1) % cfg.SOLVER.SNAPSHOT_ITERS != 0:
            return
        if self.distributed and dist.get_rank() > 0:
            return
        snapshot_folder = os.path.join(cfg.ROOT_DIR, 'snapshot')
        if not os.path.exists(snapshot_folder):
            os.mkdir(snapshot_folder)
        # Unwrap DataParallel/DDP so the checkpoint keys have no "module." prefix.
        model_to_save = (self.model.module
                         if hasattr(self.model, "module") else self.model)
        torch.save(model_to_save.state_dict(),
                   self.snapshot_path("pytorch_model", epoch + 1))

    def train(self):
        max_num_iter = len(self.train_dataset_loader)
        for epochId in range(int(cfg.SOLVER.NUM_TRAIN_EPOCHS)):
            if self.train_sampler is not None:
                self.train_sampler.set_epoch(epochId)
            self.model.train()

            start = time.time()
            batch_time = AverageMeter()
            for step, batch in enumerate(self.train_dataset_loader):
                iterId = step + (epochId * max_num_iter)
                self.optim.zero_grad()

                batch = tuple(t.cuda(device=self.device, non_blocking=True) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids, image_feat, \
                    image_loc, imgfeat_cls_prob, imgfeat_label, imgfeat_mask = batch

                # TODO Train model

            self.save_model(epochId)
            if self.distributed:
                dist.barrier()
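
# Hedged sketch of the step elided by "# TODO Train model" above: it assumes the
# pre-training model returns either a scalar loss or a tuple/list of per-task
# losses (masked LM, masked region prediction, ...) when called with the
# unpacked batch tensors. The actual forward signature of BaseBertPreTraining
# and the loss bookkeeping in this repo may differ.
def pretraining_step(model, optim, batch_inputs, n_gpu=1):
    """One forward/backward/update pass; returns the scalar loss value."""
    outputs = model(**batch_inputs)   # batch_inputs: dict of the unpacked batch tensors
    loss = sum(outputs) if isinstance(outputs, (tuple, list)) else outputs
    if n_gpu > 1:
        loss = loss.mean()            # average losses gathered by DataParallel
    loss.backward()
    optim.step()                      # zero_grad() is already called at the top of the loop
    return loss.item()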