def main(argv) -> None:
    del argv
    # Process the configuration from flags.
    config = process_config()
    if config.mode != "evaluate":
        # Define the datasets.
        train_dataset = TripletDataset(batch_size=config.batch_size,
                                       folder="datasets/triplet_aracati/train",
                                       x_shape=config.input_shape,
                                       y_shape=config.output_shape,
                                       is_validation=False)
        valid_dataset = TripletDataset(
            batch_size=config.batch_size,
            folder="datasets/triplet_aracati/validation",
            x_shape=config.input_shape,
            y_shape=config.output_shape,
            is_validation=True)
        # Define the sonar model.
        son_loss = TripletLoss()
        son_ranger = Lookahead(
            RectifiedAdam(learning_rate=config.learning_rate),
            sync_period=6,
            slow_step_size=0.5)
        son_model = EncodeNet(filters=config.filters,
                              loss=son_loss,
                              optimizer=son_ranger)
        # Define the satellite model.
        sat_loss = TripletLoss()
        sat_ranger = Lookahead(
            RectifiedAdam(learning_rate=config.learning_rate),
            sync_period=6,
            slow_step_size=0.5)
        sat_model = EncodeNet(filters=config.filters,
                              loss=sat_loss,
                              optimizer=sat_ranger)
        # Define the logger.
        logger = Logger()
        # Define the trainer.
        trainer = TripletTrainer(son_model=son_model,
                                 sat_model=sat_model,
                                 logger=logger,
                                 train_dataset=train_dataset,
                                 valid_dataset=valid_dataset)
        if config.mode == "restore":
            trainer.load_checkpoint()
        trainer.train()
    else:
        logging.fatal("Evaluation not implemented yet.")

def main(argv) -> None:
    del argv
    # Process the configuration from flags.
    config = process_config()
    if config.mode != "evaluate":
        # Define the datasets.
        train_dataset = MatchingDataset(
            batch_size=config.batch_size,
            folder="datasets/matching_aracati/train",
            x_shape=config.input_shape,
            y_shape=config.output_shape,
            is_evaluating=False)
        valid_dataset = MatchingDataset(
            batch_size=config.batch_size,
            folder="datasets/matching_aracati/validation",
            x_shape=config.input_shape,
            y_shape=config.output_shape,
            is_evaluating=False)
        # Define the model.
        loss = tf.keras.losses.BinaryCrossentropy()
        ranger = Lookahead(RectifiedAdam(learning_rate=config.learning_rate),
                           sync_period=6,
                           slow_step_size=0.5)
        model = DizygoticNet(filters=config.filters,
                             loss=loss,
                             optimizer=ranger)
        # Define the logger.
        logger = Logger()
        # Define the trainer.
        trainer = MatchingTrainer(model=model,
                                  logger=logger,
                                  train_dataset=train_dataset,
                                  valid_dataset=valid_dataset)
        if config.mode == "restore":
            trainer.load_checkpoint()
        trainer.train()
    else:
        # Define the test dataset.
        test_dataset = MatchingDataset(batch_size=1,
                                       folder="datasets/matching_aracati/test",
                                       x_shape=config.input_shape,
                                       y_shape=config.output_shape,
                                       is_evaluating=True)
        # Define the model.
        model = DizygoticNet(filters=config.filters, loss=None, optimizer=None)
        # Define the evaluator.
        evaluator = MatchingEvaluator(model=model, dataset=test_dataset)
        evaluator.load_checkpoint()
        evaluator.evaluate()

def main(argv) -> None:
    del argv
    # Process the configuration from flags.
    config = process_config()
    if config.mode != "evaluate":
        # Define the datasets.
        train_dataset = GeneralDataset(batch_size=config.batch_size,
                                       folder="datasets/general_aracati/train",
                                       x_shape=config.input_shape,
                                       y_shape=config.satellite_shape,
                                       z_shape=config.output_shape)
        valid_dataset = GeneralDataset(
            batch_size=config.batch_size,
            folder="datasets/general_aracati/validation",
            x_shape=config.input_shape,
            y_shape=config.satellite_shape,
            z_shape=config.output_shape)
        # Define the model.
        loss = tf.keras.losses.MeanAbsoluteError()
        ranger = Lookahead(RectifiedAdam(learning_rate=config.learning_rate),
                           sync_period=6,
                           slow_step_size=0.5)
        model = WNet(filters=config.filters, loss=loss, optimizer=ranger)
        # Define the logger.
        logger = Logger()
        # Define the trainer.
        trainer = GeneralTrainer(model=model,
                                 logger=logger,
                                 train_dataset=train_dataset,
                                 valid_dataset=valid_dataset)
        if config.mode == "restore":
            trainer.load_checkpoint()
        trainer.train()
    else:
        # Define the test dataset.
        test_dataset = GeneralDataset(batch_size=1,
                                      folder="datasets/general_aracati/test",
                                      x_shape=config.input_shape,
                                      y_shape=config.satellite_shape,
                                      z_shape=config.output_shape)
        # Define the model.
        model = WNet(filters=config.filters, loss=None, optimizer=None)
        # Define the evaluator.
        evaluator = GeneralEvaluator(model=model, dataset=test_dataset)
        evaluator.load_checkpoint()
        evaluator.evaluate()

def main(argv) -> None:
    del argv
    # Process the configuration from flags.
    config = process_config()
    if config.mode != "evaluate":
        # Define the datasets.
        train_dataset = SegmentationDataset(
            batch_size=config.batch_size,
            folder="datasets/segmentation_aracati/train",
            x_shape=config.input_shape,
            y_shape=config.output_shape)
        valid_dataset = SegmentationDataset(
            batch_size=config.batch_size,
            folder="datasets/segmentation_aracati/validation",
            x_shape=config.input_shape,
            y_shape=config.output_shape)
        # Define the model.
        loss = tf.keras.losses.CategoricalCrossentropy()
        ranger = Lookahead(RectifiedAdam(learning_rate=config.learning_rate),
                           sync_period=6,
                           slow_step_size=0.5)
        model = UNet(filters=config.filters, loss=loss, optimizer=ranger)
        # Define the logger.
        logger = Logger()
        # Define the trainer.
        trainer = SegmentationTrainer(model=model,
                                      logger=logger,
                                      train_dataset=train_dataset,
                                      valid_dataset=valid_dataset)
        if config.mode == "restore":
            trainer.load_checkpoint()
        trainer.train()
    else:
        logging.fatal("Evaluation not implemented yet.")

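
# The entry points above all build the same "Ranger"-style optimizer: a
# RectifiedAdam base wrapped in Lookahead with sync_period=6 and
# slow_step_size=0.5. Below is a minimal, self-contained sketch of that
# construction, assuming the TensorFlow Addons implementations
# (tfa.optimizers.RectifiedAdam and tfa.optimizers.Lookahead); the toy Keras
# model is hypothetical and only shows how the wrapped optimizer is passed on.
import tensorflow as tf
import tensorflow_addons as tfa


def build_ranger(learning_rate: float = 1e-3) -> tf.keras.optimizers.Optimizer:
    base = tfa.optimizers.RectifiedAdam(learning_rate=learning_rate)
    # Lookahead keeps a copy of "slow" weights and pulls them toward the fast
    # weights every `sync_period` steps by a factor of `slow_step_size`.
    return tfa.optimizers.Lookahead(base, sync_period=6, slow_step_size=0.5)


if __name__ == "__main__":
    toy_model = tf.keras.Sequential([
        tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
        tf.keras.layers.Dense(1),
    ])
    toy_model.compile(optimizer=build_ranger(1e-3), loss="mse")
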
def Over9000(params, alpha=0.5, k=6, *args, **kwargs):
    ralamb = Ralamb(params, *args, **kwargs)
    # Pass alpha and k by keyword so the argument order of the Lookahead
    # constructor cannot silently swap them.
    return Lookahead(ralamb, alpha=alpha, k=k)

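
# Over9000 returns a Ralamb base (RAdam with layer-wise trust ratios, sometimes
# referred to as RangerLars when combined with Lookahead) wrapped in Lookahead,
# and the result is stepped like any regular torch.optim optimizer. A usage
# sketch under that assumption; the toy model and data are hypothetical, and
# Ralamb/Lookahead/Over9000 are taken to be the implementations in this repo.
import torch
import torch.nn as nn


def toy_training_step():
    model = nn.Linear(10, 1)
    optimizer = Over9000(model.parameters(), lr=1e-3, alpha=0.5, k=6)
    criterion = nn.MSELoss()

    x, y = torch.randn(32, 10), torch.randn(32, 1)
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()  # Lookahead syncs its slow weights every k inner steps.
    return loss.item()
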
class Learner:
    def __init__(self, model, data_tuple_dict, config):
        self.model = model
        self.criterion = nn.BCEWithLogitsLoss()
        base_optim = Lamb(params=self.model.parameters(),
                          lr=1e-5,
                          weight_decay=1.2e-6,
                          min_trust=0.25)
        self.optim = Lookahead(base_optimizer=base_optim, k=5, alpha=0.8)
        self.lr_scheduler = CyclicLR(self.optim,
                                     base_lr=1e-5,
                                     max_lr=5e-5,
                                     cycle_momentum=False)
        self.train_tuple = data_tuple_dict["train_tuple"]
        self.valid_tuple = data_tuple_dict["valid_tuple"]
        self.test_tuple = data_tuple_dict["test_tuple"]
        self.device = (torch.device("cuda")
                       if torch.cuda.is_available() else torch.device("cpu"))
        self.output = home + "/snap/"
        os.makedirs(self.output, exist_ok=True)
        self.model.to(self.device)
        self.adaptive = config["adaptive_enable"]
        self.measure_flops = config["measure_flops"]
        if self.measure_flops:
            # Imported here only to fail fast if thop is missing; train()
            # imports it again where profiling actually happens.
            from thop import clever_format, profile  # noqa: F401
        self.sparse = config["sparse_enable"]
        if config["load_model"] is None:
            load_lxmert_qa(load_lxmert_qa_path,
                           self.model,
                           label2ans=self.train_tuple[0].label2ans)

    def train(self, num_epochs):
        dset, loader, evaluator = self.train_tuple
        best_valid = 0.0
        iter_wrapper = lambda x: tqdm(x, total=len(loader))
        for epoch in range(num_epochs):
            t0 = time.time()
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent,
                    target) in iter_wrapper(enumerate(loader)):
                self.model.train()
                self.optim.zero_grad()
                feats, boxes, target = (
                    feats.to(self.device),
                    boxes.to(self.device),
                    target.to(self.device),
                )
                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                loss = self.criterion(logit, target) * logit.size(1)
                if self.adaptive:
                    adapt_span_loss = 0.0
                    for l in self.model.lxrt_encoder.model.bert.encoder.layer:
                        adapt_span_loss += l.attention.self.adaptive_span.get_loss()
                    for l in self.model.lxrt_encoder.model.bert.encoder.x_layers:
                        adapt_span_loss += l.visual_attention.att.adaptive_span.get_loss()
                    for l in self.model.lxrt_encoder.model.bert.encoder.x_layers:
                        adapt_span_loss += l.lang_self_att.self.adaptive_span.get_loss()
                    for l in self.model.lxrt_encoder.model.bert.encoder.x_layers:
                        adapt_span_loss += l.visn_self_att.self.adaptive_span.get_loss()
                    for l in self.model.lxrt_encoder.model.bert.encoder.r_layers:
                        adapt_span_loss += l.attention.self.adaptive_span.get_loss()
                    loss += adapt_span_loss

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
                self.optim.step()
                self.lr_scheduler.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

                if self.adaptive:
                    for l in self.model.lxrt_encoder.model.bert.encoder.layer:
                        l.attention.self.adaptive_span.clamp_param()
                    for l in self.model.lxrt_encoder.model.bert.encoder.x_layers:
                        l.visual_attention.att.adaptive_span.clamp_param()
                    for l in self.model.lxrt_encoder.model.bert.encoder.x_layers:
                        l.lang_self_att.self.adaptive_span.clamp_param()
                    for l in self.model.lxrt_encoder.model.bert.encoder.x_layers:
                        l.visn_self_att.self.adaptive_span.clamp_param()
                    for l in self.model.lxrt_encoder.model.bert.encoder.r_layers:
                        l.attention.self.adaptive_span.clamp_param()

            log_str = "\nEpoch %d: Train %0.2f\n" % (
                epoch,
                evaluator.evaluate(quesid2ans) * 100.0,
            )
            log_str += "Loss: " + str(loss.item()) + "\t"
            if self.adaptive:
                log_str += "\tAdapt Span Loss: " + str(adapt_span_loss.item()) + "\n"
            if self.measure_flops:
                from thop import clever_format, profile
                macs, params = profile(self.model, inputs=(feats, boxes, sent))
                macs, params = clever_format([macs, params], "%.3f")
                log_str += "\nMacs: " + macs + "\tParams: " + params + "\n"
            if self.adaptive:
                for layer_idx, i in enumerate(
                        self.model.lxrt_encoder.model.bert.encoder.layer):
                    l = i.attention.self.adaptive_span.get_current_avg_span()
                    log_str += "Self Language %d %d\t" % (layer_idx, l)
                log_str += "\n"
                for layer_idx, i in enumerate(
                        self.model.lxrt_encoder.model.bert.encoder.x_layers):
                    l = i.visual_attention.att.adaptive_span.get_current_avg_span()
                    log_str += "Cross %d %d\t" % (layer_idx, l)
                log_str += "\n"
                for layer_idx, i in enumerate(
                        self.model.lxrt_encoder.model.bert.encoder.x_layers):
                    l = i.lang_self_att.self.adaptive_span.get_current_avg_span()
                    log_str += "Cross Self Language %d %d\t" % (layer_idx, l)
                log_str += "\n"
                for layer_idx, i in enumerate(
                        self.model.lxrt_encoder.model.bert.encoder.x_layers):
                    l = i.visn_self_att.self.adaptive_span.get_current_avg_span()
                    log_str += "Cross Self Vision %d %d\t" % (layer_idx, l)
                log_str += "\n"
                for layer_idx, i in enumerate(
                        self.model.lxrt_encoder.model.bert.encoder.r_layers):
                    l = i.attention.self.adaptive_span.get_current_avg_span()
                    log_str += "Self Vision %d %d\t" % (layer_idx, l)
            # if self.sparse:
            #     alpha_val = {}
            #     for l in self.model.lxrt_encoder.model.bert.encoder.layer:
            #         alpha_val["lang_layer"] = l.attention.self.entmax_alpha.alpha_chooser
            #     for l in self.model.lxrt_encoder.model.bert.encoder.x_layers:
            #         alpha_val["cross_layer"] = l.visual_attention.att.entmax_alpha.alpha_chooser
            #     for l in self.model.lxrt_encoder.model.bert.encoder.x_layers:
            #         alpha_val["cross_lang_layer"] = l.lang_self_att.self.entmax_alpha.alpha_chooser
            #     for l in self.model.lxrt_encoder.model.bert.encoder.x_layers:
            #         alpha_val["cross_vision_layer"] = l.visn_self_att.self.entmax_alpha.alpha_chooser
            #     for l in self.model.lxrt_encoder.model.bert.encoder.r_layers:
            #         alpha_val["vision_layer"] = l.attention.self.entmax_alpha.alpha_chooser
            #     print("Alpha Values from Entmax have been saved at " +
            #           home + '/snap/alpha_val_' + str(epoch) + '.pth')
            #     torch.save(alpha_val, home + '/snap/alpha_val_' + str(epoch) + '.pth')

            if self.valid_tuple is not None:  # Do validation.
                valid_score = self.evaluate(self.valid_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")
                log_str += "Epoch %d: Valid %0.2f\n" % (
                    epoch,
                    valid_score * 100.0,
                ) + "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.0)

            current_time = time.time() - t0
            print(current_time)
            log_str += "Time elapsed for epoch %f\n" % (current_time)
            print(log_str, end="")
            with open(self.output + "/log.log", "a") as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        iter_wrapper = lambda x: tqdm(x, total=len(loader))
        print("Predict in progress")
        for i, datum_tuple in iter_wrapper(enumerate(loader)):
            ques_id, feats, boxes, sent = datum_tuple[:4]  # Avoid seeing ground truth.
            with torch.no_grad():
                feats, boxes = feats.to(self.device), boxes.to(self.device)
                logit = self.model(feats, boxes, sent)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        state_dict = torch.load("%s.pth" % path, map_location=self.device)
        self.model.load_state_dict(state_dict)

def preload(self):
    logging.debug(json.dumps(self.config, indent=4, sort_keys=True))

    fields = SemEvalECFR_Dataset.prepare_fields(
        pad_t=self.tokenizer.pad_token_id)
    data = SemEvalECFR_Dataset(TRAIN_F_SEMEVAL2020_5,
                               fields=fields,
                               tokenizer=self.tokenizer,
                               model_type=self.config["model_type"])
    with open(self._RANDOM_STATE_PATH, "rb") as f:
        random_state = pickle.load(f)
    train, val = data.split(split_ratio=0.9, random_state=random_state)

    train_iter = BucketIterator(train,
                                sort_key=lambda x: -len(x.sentence),
                                shuffle=True,
                                sort=False,
                                batch_size=self.config["batch_size"],
                                train=True,
                                repeat=False,
                                device=self.device)
    val_iter = BucketIterator(
        val,
        sort_key=lambda x: -len(x.sentence),
        shuffle=False,
        sort=True,
        batch_size=self.config["validation_batch_size"],
        train=False,
        repeat=False,
        device=self.device)

    def load_model() -> TransformerForECFR:
        return TransformerForECFR(self.config,
                                  sep_token=self.tokenizer.sep_token_id,
                                  pad_token=self.tokenizer.pad_token_id)

    if "model_path" in self.config:
        m = torch.load(self.config["model_path"])
        state_dict = m.state_dict() if hasattr(m, "state_dict") else m
        model = load_model()
        model.load_state_dict(state_dict)
    else:
        model = load_model()

    model = model.to(self.device)
    model.config = self.config
    logging.info(f"Model has {count_parameters(model)} parameters")
    param_sizes, param_shapes = report_parameters(model)
    param_sizes = "\n'".join(str(param_sizes).split(", '"))
    param_shapes = "\n'".join(str(param_shapes).split(", '"))
    logging.debug(f"Model structure:\n{param_sizes}\n{param_shapes}\n")

    if self.config["optimizer"] in ("adam", "adamw"):
        optimizer = AdamW(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=self.config["learning_rate"],
                          weight_decay=self.config["weight_decay"])
    else:
        raise NotImplementedError(
            f"Option {self.config['optimizer']} for \"optimizer\" setting is undefined."
        )

    if self.config["lookahead_optimizer"]:
        optimizer = Lookahead(optimizer,
                              k=self.config["lookahead_K"],
                              alpha=self.config["lookahead_alpha"])

    return model, optimizer, train_iter, val_iter

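
# The "lookahead_K" and "lookahead_alpha" settings above map onto Lookahead's
# two hyper-parameters: every K inner-optimizer steps the slow weights move
# toward the current fast weights by a factor alpha, and the fast weights are
# reset to the slow ones. A minimal numeric sketch of that update rule for a
# single scalar parameter (an illustration only, not this repo's Lookahead):
def lookahead_schedule(fast_steps, k=10, alpha=0.5, w0=0.0):
    """Trace slow/fast weights for a 1-D parameter given per-step fast updates."""
    slow = fast = w0
    history = []
    for step, delta in enumerate(fast_steps, start=1):
        fast += delta                      # inner (fast) optimizer step
        if step % k == 0:                  # every k steps: synchronize
            slow = slow + alpha * (fast - slow)
            fast = slow                    # fast weights restart from slow weights
        history.append((step, fast, slow))
    return history


# Example: ten constant fast updates of +1.0 with k=5 and alpha=0.5 leave the
# slow weights at 2.5 after the first sync and 5.0 after the second.
# lookahead_schedule([1.0] * 10, k=5, alpha=0.5)
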
def get_optimizer(model, optimizer_name, optimizer_params, scheduler_name,
                  scheduler_params, n_epochs):
    opt_lower = optimizer_name.lower()
    opt_look_ahed = optimizer_params["lookahead"]
    if opt_lower == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=optimizer_params["lr"],
                              momentum=optimizer_params["momentum"],
                              weight_decay=optimizer_params["weight_decay"],
                              nesterov=True)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=optimizer_params["lr"],
                               betas=(0.9, 0.999),
                               eps=1e-08,
                               weight_decay=0)
    elif opt_lower == 'adamw':
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=optimizer_params["lr"],
            weight_decay=optimizer_params["weight_decay"],
            eps=optimizer_params["opt_eps"])
    elif opt_lower == 'nadam':
        optimizer = torch.optim.NAdam(
            model.parameters(),
            lr=optimizer_params["lr"],
            weight_decay=optimizer_params["weight_decay"],
            eps=optimizer_params["opt_eps"])
    elif opt_lower == 'radam':
        optimizer = RAdam(model.parameters(),
                          lr=optimizer_params["lr"],
                          weight_decay=optimizer_params["weight_decay"],
                          eps=optimizer_params["opt_eps"])
    elif opt_lower == "adabelief":
        optimizer = AdaBelief(model.parameters(),
                              lr=optimizer_params["lr"],
                              eps=1e-8,
                              weight_decay=optimizer_params["weight_decay"])
    elif opt_lower == "adamp":
        optimizer = AdamP(model.parameters(),
                          lr=optimizer_params["lr"],
                          weight_decay=optimizer_params["weight_decay"])
    else:
        raise ValueError(f"Invalid optimizer: {optimizer_name}")

    if opt_look_ahed:
        optimizer = Lookahead(optimizer, alpha=0.5, k=5)

    if scheduler_name == "CosineAnnealingWarmRestarts":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            eta_min=scheduler_params["eta_min"],
            T_0=scheduler_params["T_0"],
            T_mult=scheduler_params["T_multi"])
    elif scheduler_name == "WarmRestart":
        scheduler = WarmRestart(optimizer,
                                T_max=scheduler_params["T_max"],
                                T_mult=scheduler_params["T_mul"],
                                eta_min=scheduler_params["eta_min"])
    elif scheduler_name == "MultiStepLR":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=scheduler_params["schedule"],
            gamma=scheduler_params["gamma"])

    if scheduler_params["warmup_factor"] > 0:
        scheduler = GradualWarmupSchedulerV2(
            optimizer,
            multiplier=scheduler_params["warmup_factor"],
            total_epoch=1,
            after_scheduler=scheduler)

    return optimizer, scheduler

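
# A usage sketch for get_optimizer: the dictionary keys follow what the
# function reads above, but the concrete values and the toy model are
# hypothetical. Enabling "lookahead" only wraps the chosen base optimizer;
# the returned object is still stepped like a normal torch.optim optimizer.
import torch.nn as nn

example_model = nn.Linear(10, 2)
example_optimizer_params = {
    "lookahead": True,
    "lr": 3e-4,
    "momentum": 0.9,       # only read by the 'sgd' branch
    "weight_decay": 1e-4,
    "opt_eps": 1e-8,
}
example_scheduler_params = {
    "eta_min": 1e-6,
    "T_0": 10,
    "T_multi": 1,
    "warmup_factor": 0,    # > 0 adds the GradualWarmupSchedulerV2 wrapper
}

optimizer, scheduler = get_optimizer(example_model,
                                     optimizer_name="adamw",
                                     optimizer_params=example_optimizer_params,
                                     scheduler_name="CosineAnnealingWarmRestarts",
                                     scheduler_params=example_scheduler_params,
                                     n_epochs=30)
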