def __init__(self, model, data_tuple_dict, config):
    """Wire up the model, loss, optimizer, LR schedule and data splits.

    :param model: Network to train; it is moved to CUDA when available,
        otherwise to the CPU.
    :param data_tuple_dict: Mapping that provides the ``"train_tuple"``,
        ``"valid_tuple"`` and ``"test_tuple"`` entries.
    :param config: Mapping read for ``"adaptive_enable"``,
        ``"measure_flops"``, ``"sparse_enable"`` and ``"load_model"``.
    """
    self.model = model
    self.criterion = nn.BCEWithLogitsLoss()

    # Lamb wrapped in Lookahead; the cyclic schedule drives the wrapper.
    base_optim = Lamb(params=self.model.parameters(),
                      lr=1e-5,
                      weight_decay=1.2e-6,
                      min_trust=0.25)
    self.optim = Lookahead(base_optimizer=base_optim, k=5, alpha=0.8)
    self.lr_scheduler = CyclicLR(self.optim,
                                 base_lr=1e-5,
                                 max_lr=5e-5,
                                 cycle_momentum=False)

    self.train_tuple = data_tuple_dict["train_tuple"]
    self.valid_tuple = data_tuple_dict["valid_tuple"]
    self.test_tuple = data_tuple_dict["test_tuple"]

    self.device = (torch.device("cuda")
                   if torch.cuda.is_available() else torch.device("cpu"))

    # Checkpoints and logs are written below <home>/snap/.
    self.output = home + "/snap/"
    os.makedirs(self.output, exist_ok=True)
    self.model.to(self.device)

    self.adaptive = config["adaptive_enable"]
    self.measure_flops = config["measure_flops"]
    if self.measure_flops:
        # The imported names are local to this function; the import is
        # kept so that a missing ``thop`` package fails here, at
        # construction time, rather than later.
        from thop import clever_format, profile  # noqa: F401
    self.sparse = config["sparse_enable"]

    # Without a checkpoint to restore, start from the pretrained weights.
    if config["load_model"] is None:
        load_lxmert_qa(load_lxmert_qa_path,
                       self.model,
                       label2ans=self.train_tuple[0].label2ans)
Example #2
0
def main(argv) -> None:
    """Train the sonar and satellite triplet encoders.

    Reads the run configuration via ``process_config()``. Any mode other
    than ``"evaluate"`` trains; ``"restore"`` additionally loads a
    checkpoint before training. Evaluation is reported as unimplemented.
    """
    del argv

    # Process the configuration from flags.
    config = process_config()

    # Guard clause: only training is available for this task.
    if config.mode == "evaluate":
        logging.fatal("Evaluation not implemented yet.")
        return

    # Arguments shared by the training and validation datasets.
    shared_dataset_args = dict(batch_size=config.batch_size,
                               x_shape=config.input_shape,
                               y_shape=config.output_shape)
    train_dataset = TripletDataset(folder="datasets/triplet_aracati/train",
                                   is_validation=False,
                                   **shared_dataset_args)
    valid_dataset = TripletDataset(
        folder="datasets/triplet_aracati/validation",
        is_validation=True,
        **shared_dataset_args)

    def build_encoder():
        """Create one EncodeNet with its own loss and optimizer objects."""
        triplet_loss = TripletLoss()
        ranger = Lookahead(
            RectifiedAdam(learning_rate=config.learning_rate),
            sync_period=6,
            slow_step_size=0.5)
        return EncodeNet(filters=config.filters,
                         loss=triplet_loss,
                         optimizer=ranger)

    # The sonar model is built first, then the satellite model.
    son_model = build_encoder()
    sat_model = build_encoder()

    trainer = TripletTrainer(son_model=son_model,
                             sat_model=sat_model,
                             logger=Logger(),
                             train_dataset=train_dataset,
                             valid_dataset=valid_dataset)
    if config.mode == "restore":
        trainer.load_checkpoint()
    trainer.train()
Example #3
0
def main(argv) -> None:
    """Train or evaluate the matching network according to ``config.mode``.

    ``"evaluate"`` runs the evaluator on the test split from a loaded
    checkpoint. Every other mode trains, and ``"restore"`` loads a
    checkpoint first.
    """
    del argv

    # Process the configuration from flags.
    config = process_config()

    def make_dataset(folder, batch_size, is_evaluating):
        """Build a MatchingDataset that uses the configured shapes."""
        return MatchingDataset(batch_size=batch_size,
                               folder=folder,
                               x_shape=config.input_shape,
                               y_shape=config.output_shape,
                               is_evaluating=is_evaluating)

    if config.mode == "evaluate":
        # Evaluation path: single-sample batches, no loss or optimizer.
        test_dataset = make_dataset("datasets/matching_aracati/test", 1,
                                    True)
        net = DizygoticNet(filters=config.filters, loss=None, optimizer=None)

        evaluator = MatchingEvaluator(model=net, dataset=test_dataset)
        evaluator.load_checkpoint()
        evaluator.evaluate()
        return

    # Training path.
    train_dataset = make_dataset("datasets/matching_aracati/train",
                                 config.batch_size, False)
    valid_dataset = make_dataset("datasets/matching_aracati/validation",
                                 config.batch_size, False)

    bce_loss = tf.keras.losses.BinaryCrossentropy()
    base_optimizer = RectifiedAdam(learning_rate=config.learning_rate)
    ranger = Lookahead(base_optimizer, sync_period=6, slow_step_size=0.5)
    net = DizygoticNet(filters=config.filters,
                       loss=bce_loss,
                       optimizer=ranger)

    trainer = MatchingTrainer(model=net,
                              logger=Logger(),
                              train_dataset=train_dataset,
                              valid_dataset=valid_dataset)
    if config.mode == "restore":
        trainer.load_checkpoint()
    trainer.train()
Example #4
0
def main(argv) -> None:
    """Train or evaluate the WNet model according to ``config.mode``.

    ``"evaluate"`` runs the evaluator on the test split from a loaded
    checkpoint. Every other mode trains, and ``"restore"`` loads a
    checkpoint first.
    """
    del argv

    # Process the configuration from flags.
    config = process_config()

    # Shape arguments are the same for every split.
    shapes = dict(x_shape=config.input_shape,
                  y_shape=config.satellite_shape,
                  z_shape=config.output_shape)

    if config.mode == "evaluate":
        # Evaluation path: single-sample batches, no loss or optimizer.
        test_dataset = GeneralDataset(batch_size=1,
                                      folder="datasets/general_aracati/test",
                                      **shapes)
        net = WNet(filters=config.filters, loss=None, optimizer=None)

        evaluator = GeneralEvaluator(model=net, dataset=test_dataset)
        evaluator.load_checkpoint()
        evaluator.evaluate()
        return

    # Training path.
    train_dataset = GeneralDataset(batch_size=config.batch_size,
                                   folder="datasets/general_aracati/train",
                                   **shapes)
    valid_dataset = GeneralDataset(
        batch_size=config.batch_size,
        folder="datasets/general_aracati/validation",
        **shapes)

    mae_loss = tf.keras.losses.MeanAbsoluteError()
    base_optimizer = RectifiedAdam(learning_rate=config.learning_rate)
    ranger = Lookahead(base_optimizer, sync_period=6, slow_step_size=0.5)
    net = WNet(filters=config.filters, loss=mae_loss, optimizer=ranger)

    trainer = GeneralTrainer(model=net,
                             logger=Logger(),
                             train_dataset=train_dataset,
                             valid_dataset=valid_dataset)
    if config.mode == "restore":
        trainer.load_checkpoint()
    trainer.train()
Example #5
0
def main(argv) -> None:
    """Train the UNet segmentation model.

    Reads the run configuration via ``process_config()``. Any mode other
    than ``"evaluate"`` trains; ``"restore"`` additionally loads a
    checkpoint before training. Evaluation is reported as unimplemented.
    """
    del argv

    # Process the configuration from flags.
    config = process_config()

    # Guard clause: only training is available for this task.
    if config.mode == "evaluate":
        logging.fatal("Evaluation not implemented yet.")
        return

    # Arguments shared by the training and validation datasets.
    shared_dataset_args = dict(batch_size=config.batch_size,
                               x_shape=config.input_shape,
                               y_shape=config.output_shape)
    train_dataset = SegmentationDataset(
        folder="datasets/segmentation_aracati/train",
        **shared_dataset_args)
    valid_dataset = SegmentationDataset(
        folder="datasets/segmentation_aracati/validation",
        **shared_dataset_args)

    cce_loss = tf.keras.losses.CategoricalCrossentropy()
    base_optimizer = RectifiedAdam(learning_rate=config.learning_rate)
    ranger = Lookahead(base_optimizer, sync_period=6, slow_step_size=0.5)
    net = UNet(filters=config.filters, loss=cce_loss, optimizer=ranger)

    trainer = SegmentationTrainer(model=net,
                                  logger=Logger(),
                                  train_dataset=train_dataset,
                                  valid_dataset=valid_dataset)
    if config.mode == "restore":
        trainer.load_checkpoint()
    trainer.train()
Example #6
0
def Over9000(params, alpha=0.5, k=6, *args, **kwargs):
    """Build a ``Ralamb`` optimizer and wrap it in ``Lookahead``.

    Extra positional and keyword arguments are forwarded to ``Ralamb``.
    ``alpha`` and ``k`` are handed to ``Lookahead`` positionally, in that
    order, after the inner optimizer.
    """
    return Lookahead(Ralamb(params, *args, **kwargs), alpha, k)
class Learner:
    """Train, evaluate and checkpoint a question-answering model.

    The learner owns the loss, a Lookahead-wrapped Lamb optimizer, a cyclic
    learning-rate schedule and the train/valid/test data tuples. Each data
    tuple is unpacked as ``(dataset, loader, evaluator)``.
    """

    def __init__(self, model, data_tuple_dict, config):
        """Wire up the model, loss, optimizer, LR schedule and data splits.

        :param model: Network to train; it is moved to CUDA when available,
            otherwise to the CPU.
        :param data_tuple_dict: Mapping that provides the ``"train_tuple"``,
            ``"valid_tuple"`` and ``"test_tuple"`` entries.
        :param config: Mapping read for ``"adaptive_enable"``,
            ``"measure_flops"``, ``"sparse_enable"`` and ``"load_model"``.
        """
        self.model = model
        self.criterion = nn.BCEWithLogitsLoss()
        base_optim = Lamb(params=self.model.parameters(),
                          lr=1e-5,
                          weight_decay=1.2e-6,
                          min_trust=0.25)
        self.optim = Lookahead(base_optimizer=base_optim, k=5, alpha=0.8)
        self.lr_scheduler = CyclicLR(self.optim,
                                     base_lr=1e-5,
                                     max_lr=5e-5,
                                     cycle_momentum=False)
        self.train_tuple = data_tuple_dict["train_tuple"]
        self.valid_tuple = data_tuple_dict["valid_tuple"]
        self.test_tuple = data_tuple_dict["test_tuple"]
        self.device = (torch.device("cuda")
                       if torch.cuda.is_available() else torch.device("cpu"))

        # Checkpoints and logs are written below <home>/snap/.
        self.output = home + "/snap/"
        os.makedirs(self.output, exist_ok=True)
        self.model.to(self.device)

        self.adaptive = config["adaptive_enable"]
        self.measure_flops = config["measure_flops"]
        if self.measure_flops:
            # Fail fast when thop is not installed. A function-scope import
            # does not leak into other methods, so train() imports the
            # names again where they are actually used.
            from thop import clever_format, profile  # noqa: F401
        self.sparse = config["sparse_enable"]

        # Without a checkpoint to restore, start from the pretrained weights.
        if config["load_model"] is None:
            load_lxmert_qa(load_lxmert_qa_path,
                           self.model,
                           label2ans=self.train_tuple[0].label2ans)

    def _adaptive_span_groups(self):
        """Return ``(log label, adaptive-span modules)`` pairs.

        The groups follow the traversal order used for the span loss,
        parameter clamping and logging: language layers, the three
        attention blocks of the cross layers, then the vision layers.
        """
        encoder = self.model.lxrt_encoder.model.bert.encoder
        return [
            ("Self Language",
             [layer.attention.self.adaptive_span for layer in encoder.layer]),
            ("Cross", [
                layer.visual_attention.att.adaptive_span
                for layer in encoder.x_layers
            ]),
            ("Cross Self Language", [
                layer.lang_self_att.self.adaptive_span
                for layer in encoder.x_layers
            ]),
            ("Cross Self Vision", [
                layer.visn_self_att.self.adaptive_span
                for layer in encoder.x_layers
            ]),
            ("Self Vision", [
                layer.attention.self.adaptive_span
                for layer in encoder.r_layers
            ]),
        ]

    def train(self, num_epochs):
        """Run ``num_epochs`` training epochs over the training loader.

        After every epoch a summary is printed and appended to ``log.log``
        in the output directory. When a validation tuple is set, the model
        is validated and saved as ``BEST`` whenever the validation score
        improves. The final weights are saved as ``LAST``.

        :param num_epochs: Number of passes over the training loader.
        """
        dset, loader, evaluator = self.train_tuple
        best_valid = 0.0

        if self.measure_flops:
            # Imported here because the names bound by the import in
            # __init__ are local to that method and not visible here.
            from thop import clever_format, profile

        # The module references do not change between batches, so the
        # traversal is resolved once per train() call.
        span_groups = self._adaptive_span_groups() if self.adaptive else []
        spans = [span for _, group in span_groups for span in group]

        for epoch in range(num_epochs):
            t0 = time.time()
            quesid2ans = {}
            for ques_id, feats, boxes, sent, target in tqdm(
                    loader, total=len(loader)):
                self.model.train()
                self.optim.zero_grad()
                feats, boxes, target = (
                    feats.to(self.device),
                    boxes.to(self.device),
                    target.to(self.device),
                )

                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                # The criterion result is multiplied by the size of logit
                # dimension 1 (the per-sample output width).
                loss = self.criterion(logit, target) * logit.size(1)

                if self.adaptive:
                    adapt_span_loss = 0.0
                    for span in spans:
                        adapt_span_loss += span.get_loss()
                    loss += adapt_span_loss

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
                self.optim.step()
                self.lr_scheduler.step()

                # Record the arg-max answer for the epoch-level train score.
                _, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    quesid2ans[qid.item()] = dset.label2ans[l]

                # Clamp the span parameters after each optimizer step.
                for span in spans:
                    span.clamp_param()

            log_str = "\nEpoch %d: Train %0.2f\n" % (
                epoch,
                evaluator.evaluate(quesid2ans) * 100.0,
            )
            # Loss values below are those of the last batch of the epoch.
            log_str += "Loss: " + str(loss.item()) + "\t"

            if self.adaptive:
                log_str += "\tAdapt Span Loss: " + str(
                    adapt_span_loss.item()) + "\n"

            if self.measure_flops:
                # Profiled on the last batch of the epoch.
                macs, params = profile(self.model, inputs=(feats, boxes, sent))
                macs, params = clever_format([macs, params], "%.3f")
                log_str += "\nMacs: " + macs + "\tParams: " + params + "\n"

            if self.adaptive:
                # One line per group; no newline after the last group.
                log_str += "\n".join(
                    "".join("%s %d %d\t" %
                            (label, layer_idx, span.get_current_avg_span())
                            for layer_idx, span in enumerate(group))
                    for label, group in span_groups)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(self.valid_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (
                    epoch,
                    valid_score * 100.0,
                ) + "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.0)

            current_time = time.time() - t0
            print(current_time)
            log_str += "Time elpased for epoch %f\n" % (current_time)
            print(log_str, end="")

            with open(os.path.join(self.output, "log.log"), "a") as f:
                f.write(log_str)

        self.save("LAST")

    def predict(self, eval_tuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        print("Predict in progress")
        for datum_tuple in tqdm(loader, total=len(loader)):
            # Only the first four fields are used: avoid seeing ground truth.
            ques_id, feats, boxes, sent = datum_tuple[:4]
            with torch.no_grad():
                feats, boxes = feats.to(self.device), boxes.to(self.device)
                logit = self.model(feats, boxes, sent)
                _, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    quesid2ans[qid.item()] = dset.label2ans[l]
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_loader):
        """Score the answers obtained by taking the arg-max of the targets.

        The body needs the dataset (for ``label2ans``) and the evaluator in
        addition to the loader, so ``data_loader`` must be a full data
        tuple ``(dataset, loader, evaluator)``. Previously ``dset`` and
        ``evaluator`` were undefined names inside this method.

        :param data_loader: A ``(dataset, loader, evaluator)`` data tuple.
        :return: The evaluator's score for the target-derived answers.
        """
        dset, loader, evaluator = data_loader
        quesid2ans = {}
        for ques_id, feats, boxes, sent, target in loader:
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                quesid2ans[qid.item()] = dset.label2ans[l]
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        """Save the model's state dict as ``<name>.pth`` in the output dir."""
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        """Load a state dict from ``<path>.pth`` onto the learner's device."""
        state_dict = torch.load("%s.pth" % path, map_location=self.device)
        self.model.load_state_dict(state_dict)
    def preload(self):
        """Build the data iterators, the model and the optimizer for a run.

        Reads ``self.config`` for ``"model_type"``, ``"batch_size"``,
        ``"validation_batch_size"``, ``"optimizer"``, ``"learning_rate"``,
        ``"weight_decay"``, ``"lookahead_optimizer"`` (plus ``"lookahead_K"``
        and ``"lookahead_alpha"`` when enabled) and the optional
        ``"model_path"``.

        :return: ``(model, optimizer, train_iter, val_iter)``.
        :raises NotImplementedError: If the configured optimizer is neither
            ``"adam"`` nor ``"adamw"``.
        """
        logging.debug(json.dumps(self.config, indent=4, sort_keys=True))

        fields = SemEvalECFR_Dataset.prepare_fields(
            pad_t=self.tokenizer.pad_token_id)
        data = SemEvalECFR_Dataset(TRAIN_F_SEMEVAL2020_5,
                                   fields=fields,
                                   tokenizer=self.tokenizer,
                                   model_type=self.config["model_type"])

        # The stored random state is reused so the train/validation split is
        # the same each time this file is read.
        # NOTE(review): pickle.load can execute arbitrary code; this is only
        # safe while _RANDOM_STATE_PATH points at a trusted local file.
        with open(self._RANDOM_STATE_PATH, "rb") as f:
            random_state = pickle.load(f)
        train, val = data.split(split_ratio=0.9, random_state=random_state)

        def by_descending_length(example):
            """Sort key used by both iterators: negated sentence length."""
            return -(len(example.sentence))

        train_iter = BucketIterator(train,
                                    sort_key=by_descending_length,
                                    shuffle=True,
                                    sort=False,
                                    batch_size=self.config["batch_size"],
                                    train=True,
                                    repeat=False,
                                    device=self.device)
        val_iter = BucketIterator(
            val,
            sort_key=by_descending_length,
            shuffle=False,
            sort=True,
            batch_size=self.config["validation_batch_size"],
            train=False,
            repeat=False,
            device=self.device)

        model = TransformerForECFR(self.config,
                                   sep_token=self.tokenizer.sep_token_id,
                                   pad_token=self.tokenizer.pad_token_id)
        if "model_path" in self.config:
            # map_location lets a checkpoint saved on another device (for
            # example a GPU) be restored on the device of this run.
            checkpoint = torch.load(self.config["model_path"],
                                    map_location=self.device)
            # The file may hold either a whole model or a bare state dict.
            state_dict = (checkpoint.state_dict() if hasattr(
                checkpoint, "state_dict") else checkpoint)
            model.load_state_dict(state_dict)
        model = model.to(self.device)
        model.config = self.config

        logging.info(f"Models has {count_parameters(model)} parameters")
        param_sizes, param_shapes = report_parameters(model)
        # Put each reported entry on its own line for readability.
        param_sizes = "\n'".join(str(param_sizes).split(", '"))
        param_shapes = "\n'".join(str(param_shapes).split(", '"))
        logging.debug(f"Model structure:\n{param_sizes}\n{param_shapes}\n")

        if self.config["optimizer"] in ("adam", "adamw"):
            # Only parameters that require gradients are optimized.
            optimizer = AdamW(filter(lambda p: p.requires_grad,
                                     model.parameters()),
                              lr=self.config["learning_rate"],
                              weight_decay=self.config["weight_decay"])
        else:
            raise NotImplementedError(
                f"Option {self.config['optimizer']} for \"optimizer\" setting is undefined."
            )

        if self.config["lookahead_optimizer"]:
            optimizer = Lookahead(optimizer,
                                  k=self.config["lookahead_K"],
                                  alpha=self.config["lookahead_alpha"])
        return model, optimizer, train_iter, val_iter
Example #9
0
def get_optimizer(model, optimizer_name, optimizer_params, scheduler_name,
                  scheduler_params, n_epochs):
    """Create the optimizer and learning-rate scheduler for ``model``.

    :param model: Module whose ``parameters()`` are optimized.
    :param optimizer_name: Case-insensitive name; one of ``sgd``, ``adam``,
        ``adamw``, ``nadam``, ``radam``, ``adabelief`` or ``adamp``.
    :param optimizer_params: Mapping with ``"lookahead"`` and ``"lr"``, plus
        the keys the chosen optimizer reads (``"momentum"``,
        ``"weight_decay"``, ``"opt_eps"``).
    :param scheduler_name: ``"CosineAnnealingWarmRestarts"``,
        ``"WarmRestart"`` or ``"MultiStepLR"``.
    :param scheduler_params: Mapping with ``"warmup_factor"`` plus the keys
        the chosen scheduler reads.
    :param n_epochs: Accepted for interface compatibility; not used here.
    :return: ``(optimizer, scheduler)``.
    :raises ValueError: If the optimizer or scheduler name is not supported.
    """
    opt_lower = optimizer_name.lower()
    use_lookahead = optimizer_params["lookahead"]

    if opt_lower == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=optimizer_params["lr"],
                              momentum=optimizer_params["momentum"],
                              weight_decay=optimizer_params["weight_decay"],
                              nesterov=True)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=optimizer_params["lr"],
                               betas=(0.9, 0.999),
                               eps=1e-08,
                               weight_decay=0)
    elif opt_lower == 'adamw':
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=optimizer_params["lr"],
            weight_decay=optimizer_params["weight_decay"],
            eps=optimizer_params["opt_eps"])
    elif opt_lower == 'nadam':
        # The class is spelled ``NAdam``; ``torch.optim.Nadam`` does not
        # exist and raised AttributeError.
        optimizer = torch.optim.NAdam(
            model.parameters(),
            lr=optimizer_params["lr"],
            weight_decay=optimizer_params["weight_decay"],
            eps=optimizer_params["opt_eps"])
    elif opt_lower == 'radam':
        optimizer = RAdam(model.parameters(),
                          lr=optimizer_params["lr"],
                          weight_decay=optimizer_params["weight_decay"],
                          eps=optimizer_params["opt_eps"])
    elif opt_lower == "adabelief":
        optimizer = AdaBelief(model.parameters(),
                              lr=optimizer_params["lr"],
                              eps=1e-8,
                              weight_decay=optimizer_params["weight_decay"])
    elif opt_lower == "adamp":
        optimizer = AdamP(model.parameters(),
                          lr=optimizer_params["lr"],
                          weight_decay=optimizer_params["weight_decay"])
    else:
        # An explicit exception instead of ``assert``: asserts are stripped
        # under ``python -O`` and the old message was never displayed.
        raise ValueError(f"Invalid optimizer: {optimizer_name!r}")

    if use_lookahead:
        optimizer = Lookahead(optimizer, alpha=0.5, k=5)

    if scheduler_name == "CosineAnnealingWarmRestarts":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            eta_min=scheduler_params["eta_min"],
            T_0=scheduler_params["T_0"],
            T_mult=scheduler_params["T_multi"],
        )
    elif scheduler_name == "WarmRestart":
        scheduler = WarmRestart(optimizer,
                                T_max=scheduler_params["T_max"],
                                T_mult=scheduler_params["T_mul"],
                                eta_min=scheduler_params["eta_min"])
    elif scheduler_name == "MultiStepLR":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=scheduler_params["schedule"],
            gamma=scheduler_params["gamma"])
    else:
        # Previously an unknown name left ``scheduler`` unbound and the
        # function failed later with an UnboundLocalError.
        raise ValueError(f"Invalid scheduler: {scheduler_name!r}")

    # Optionally wrap the scheduler in a one-epoch gradual warm-up.
    if scheduler_params["warmup_factor"] > 0:
        scheduler = GradualWarmupSchedulerV2(
            optimizer,
            multiplier=scheduler_params["warmup_factor"],
            total_epoch=1,
            after_scheduler=scheduler)

    return optimizer, scheduler