Example #1
    def custom_test_predictions(self, polarize=False):
        model = resnets.resnet18(pretrained=False).to(self.device)
        state = torch.load("./models/resnet-18.pt", map_location=self.device)
        model.load_state_dict(state["model_state_dict"])
        model.eval()

        test_ids, test_preds = [], []
        for idx, batch in enumerate(self.test_loader):
            img, ids = batch
            img = img.to(self.device)
            with torch.no_grad():
                out = F.softmax(model(img), dim=1).detach().cpu().numpy()
            test_ids.extend(ids.cpu().numpy().tolist())
            test_preds.extend(out[:, 0].tolist())
            common.progress_bar(progress=(idx + 1) / len(self.test_loader),
                                status="")

        common.progress_bar(progress=1.0, status="")
        sub_df = pd.DataFrame({"id": test_ids, "p_real": test_preds})
        sub_df["id"] = sub_df["id"].astype("int")
        sub_df["p_real"] = sub_df["p_real"].astype("float").apply(
            lambda x: min(max(0.05, x), 0.95))
        sub_df = sub_df.sort_values(by="id", ascending=True)
        if polarize:
            sub_df["p_real"] = sub_df["p_real"].apply(
                DeepfakeClassifier.polarize_predictions)
            sub_df.to_csv(os.path.join(self.output_dir,
                                       "polarized_test_predictions.csv"),
                          index=False)
        else:
            sub_df.to_csv(os.path.join(self.output_dir,
                                       "test_predictions.csv"),
                          index=False)
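`DeepfakeClassifier.polarize_predictions` is referenced here and in Example #2 but never shown. A minimal sketch of what such a static method might look like, assuming it simply pushes confident scores toward the extremes (the thresholds and output values below are hypothetical):

    @staticmethod
    def polarize_predictions(p, low=0.25, high=0.75):
        # hypothetical: push confident scores toward 0/1, keep uncertain ones as-is
        if p <= low:
            return 0.01
        if p >= high:
            return 0.99
        return p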
Example #2
    def get_test_predictions(self, polarize=True):
        self.model.eval()
        if self.args["load"] is None:
            self.load_model(self.output_dir)

        test_ids, test_preds = [], []
        for idx, batch in enumerate(self.test_loader):
            img, ids = batch
            img = img.to(self.device)
            with torch.no_grad():
                out = self.model(img).detach().cpu().numpy()
            test_ids.extend(ids.cpu().numpy().tolist())
            test_preds.extend(out[:, 1].tolist())
            common.progress_bar(progress=(idx + 1) / len(self.test_loader),
                                status="")

        common.progress_bar(progress=1.0, status="")
        sub_df = pd.DataFrame({"id": test_ids, "p_real": test_preds})
        sub_df["id"] = sub_df["id"].astype("int")
        sub_df["p_real"] = sub_df["p_real"].astype("float").apply(
            lambda x: min(max(0.05, x), 0.95))
        sub_df = sub_df.sort_values(by="id", ascending=True)
        if polarize:
            sub_df["p_real"] = sub_df["p_real"].apply(
                DeepfakeClassifier.polarize_predictions)
            sub_df.to_csv(os.path.join(self.output_dir,
                                       "polarized_test_predictions.csv"),
                          index=False)
        else:
            sub_df.to_csv(os.path.join(self.output_dir,
                                       "test_predictions.csv"),
                          index=False)
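Examples #1-#3 call `common.progress_bar` as a plain function taking a `progress` fraction in [0, 1] and a `status` string. The helper itself is not part of these snippets; a minimal sketch consistent with that usage, assuming it redraws a single line with a carriage return:

    import sys

    def progress_bar(progress, status="", width=30):
        # redraw one line in place: [=====>        ]  42% <status>
        filled = int(round(width * progress))
        bar = "=" * filled + " " * (width - filled)
        sys.stdout.write("\r[{}] {:3.0f}% {}".format(bar, progress * 100, status))
        sys.stdout.flush()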
Example #3
    def train(self):
        for epoch in range(self.done_epochs, self.config["epochs"]):
            self.logger.record(f'Epoch {epoch+1}/{self.config["epochs"]}',
                               mode='train')
            train_meter = common.AverageMeter()

            for idx, batch in enumerate(self.train_loader):
                train_metrics = self.train_one_step(batch)
                train_meter.add(train_metrics)
                wandb.log({"Train loss": train_meter.return_metrics()["loss"]})
                common.progress_bar(progress=(idx + 1) /
                                    len(self.train_loader),
                                    status=train_meter.return_msg())

            common.progress_bar(progress=1.0, status=train_meter.return_msg())
            self.logger.write(train_meter.return_msg(), mode='train')
            self.adjust_learning_rate(epoch + 1)
            wandb.log({
                "Train accuracy": train_meter.return_metrics()["accuracy"],
                "Train F1 score": train_meter.return_metrics()["f1"],
                "Epoch": epoch + 1
            })

            if (epoch + 1) % self.config["eval_every"] == 0:
                self.logger.record(f'Epoch {epoch+1}/{self.config["epochs"]}',
                                   mode='val')
                val_meter = common.AverageMeter()
                for idx, batch in enumerate(self.val_loader):
                    val_metrics = self.validate_one_step(batch)
                    val_meter.add(val_metrics)
                    common.progress_bar(progress=(idx + 1) /
                                        len(self.val_loader),
                                        status=val_meter.return_msg())

                common.progress_bar(progress=1.0,
                                    status=val_meter.return_msg())
                self.logger.write(val_meter.return_msg(), mode='val')
                wandb.log({
                    "Validation loss": val_meter.return_metrics()["loss"],
                    "Validation accuracy": val_meter.return_metrics()["accuracy"],
                    "Validation F1 score": val_meter.return_metrics()["f1"],
                    "Epoch": epoch + 1
                })

                if val_meter.return_metrics()["loss"] < self.best_val_loss:
                    self.best_val_loss = val_meter.return_metrics()["loss"]
                    self.save_model()

        print("\n\n")
        self.logger.record("Finished training! Generating test predictions...",
                           mode='info')
        self.get_test_predictions(self.config.get("polarize_predictions",
                                                  True))
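`common.AverageMeter` appears throughout with three methods: `add(metrics)` to accumulate a dict of per-batch values, `return_metrics()` for the running averages, and `return_msg()` for a printable summary. A minimal sketch under those assumptions (the real class may weight by batch size):

    class AverageMeter:
        def __init__(self):
            self.totals, self.count = {}, 0

        def add(self, metrics):
            # accumulate one batch worth of metric values
            self.count += 1
            for k, v in metrics.items():
                self.totals[k] = self.totals.get(k, 0.0) + v

        def return_metrics(self):
            return {k: v / self.count for k, v in self.totals.items()}

        def return_msg(self):
            return " ".join("{}: {:.4f}".format(k, v)
                            for k, v in self.return_metrics().items())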
Example #4
def __progress(percent, char, newline):
    """
        Print single-lined progress information.
    """
    if char == None:
        char = ""

    status_line = \
        common.progress_bar(78, percent, "Progress:", 16, "=", char, True)
    sys.stdout.write(status_line + "\r")
    sys.stdout.flush()

    if newline:
        print()
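A typical call pattern for this helper, driving the bar in place and finishing with a newline (the loop below is illustrative only):

    for pct in range(0, 101, 5):
        __progress(pct, None, newline=(pct == 100))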
Example #5
File: check.py Project: RapidLzj/Scheduler
def check(tel, yr, mn, dy, run=""):
    """ check FITS headers, and generate a check list
    args:
        tel: telescope brief code
        yr: year of obs date, 4-digit year
        mn: month of obs date, 1 to 12
        dy: day of obs date, 0 to 31, or extended
        run: run code, default is `yyyymm`
    """
    site = schdutil.load_basic(tel)
    mjd18 = common.sky.mjd_of_night(yr, mn, dy, site)
    if run is None or run == "":
        run = "{year:04d}{month:02d}".format(year=yr, month=mn)
    # input and output filenames
    filelist = "{tel}/obsed/{run}/files.J{day:04d}.lst".format(tel=tel, run=run, day=mjd18)
    chklist  = "{tel}/obsed/{run}/check.J{day:04d}.lst".format(tel=tel, run=run, day=mjd18)

    if not os.path.isfile(filelist):
        print("ERROR!! File list does NOT exist: '{0}'".format(filelist))
        return

    # load file info
    with open(filelist, "r") as fl:
        flst = fl.readlines()
    clst = []
    fcnt = len(flst)
    pb = common.progress_bar(0, value_from=0, value_to=fcnt)
    for f in flst:
        info = headerinfo.headerinfo(f.strip())
        if info is not None:
            clst.append(info)
        pb.step()
    pb.end()

    # output check list
    with open(chklist, "w") as f:
        for c in clst:
            f.write("{}\n".format(c))

    print("Check OK! {0} files from `{1}`.\n".format(len(clst), filelist))
Example #6
def collect(tel, yr, mn, dy, run=""):
    """ collect info from check list, compare with exposure mode and plan, make obsed list
    args:
        tel: telescope brief code
        yr: year of obs date, 4-digit year
        mn: month of obs date, 1 to 12
        dy: day of obs date, 0 to 31, or extended
        run: run code, default is `yyyymm`
    """
    site = schdutil.load_basic(tel)
    mjd18 = common.sky.mjd_of_night(yr, mn, dy, site)
    if run is None or run == "":
        run = "{year:04d}{month:02d}".format(year=yr, month=mn)
    # input and output filenames
    checklist = "{tel}/obsed/{run}/check.J{day:04d}.lst".format(tel=tel, run=run, day=mjd18)
    obsedlist = "{tel}/obsed/{run}/obsed.J{day:04d}.lst".format(tel=tel, run=run, day=mjd18)

    if not os.path.isfile(checklist):
        print("ERROR!! Check list does NOT exist: '{0}'".format(checklist))
        return

    # load configure file
    plans = schdutil.load_expplan(tel)
    modes = schdutil.load_expmode(tel)

    # load check files
    chk = []
    with open(checklist, "r") as cf:
        lines = cf.readlines()
    pb = common.progress_bar(0, value_from=0, value_to=len(lines), value_step=0.5)
    for line in lines:
        c = schdutil.check_info.parse(line)
        mode = c.mode()
        c.mode = mode
        if mode in modes:
            c.code = modes[mode].code
            c.factor = modes[mode].factor
        else:
            c.code = -1
            c.factor = 0.0
        # remove trailing "x", which marks a dithered exposure
        if c.object.endswith("x"):
            c.object = c.object[:-1]
        chk.append(c)
        pb.step()

    # init an empty 2-d dict `objmap`, keyed by object name and plan code
    # a dict gives array-like storage with easy indexing
    uobj = set(c.object for c in chk)
    emptyobj = {-1: 0.0}
    for p in plans:
        emptyobj[plans[p].code] = 0.0
    objmap = {}
    for o in uobj:
        objmap[o] = emptyobj.copy()

    # fill check info into objmap
    for c in chk:
        objmap[c.object][c.code] += c.factor
        pb.step()
    pb.end()

    # output obsed file
    # use a fixed key order so output is identical across systems
    plancode = list(plans.keys())
    plancode.sort()
    with open(obsedlist, "w") as f:
        f.write("#{:<11s}".format("Object"))
        for p in plancode:
            f.write(" {:>4s}".format(plans[p].name[0:4]))
        f.write("\n\n")
        for o in objmap:
            ot = "{:<12s}".format(o)
            ft = ["{:>4.1f}".format(objmap[o][p]) for p in plancode]
            tt = ot + " " + " ".join(ft) + "\n"
            f.write(tt)

    print("Collect OK! {0} objects from {1} files of `{2}`.\n".format(
        len(objmap), len(lines), checklist))
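These two functions form a pipeline: `check` writes the check list that `collect` aggregates. A hypothetical session with a made-up telescope code (`run` falls back to `yyyymm`):

    check("bok", 2016, 3, 14)
    collect("bok", 2016, 3, 14)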
Example #7
    def train(self, run_id=0):
        print()
        for epoch in range(self.done_epochs, self.config['epochs'] + 1):
            train_meter = common.AverageMeter()
            val_meter = common.AverageMeter()
            self.train_data = self.train_data.to(self.device)
            self.val_data = self.val_data.to(self.device)
            self.logger.record('Epoch [{:3d}/{}]'.format(
                epoch, self.config['epochs']), mode='train')
            if self.scheduler is not None:
                self.adjust_learning_rate(epoch + 1)

            for idx in range(len(self.train_loader)):
                batch = self.train_loader.flow()
                train_metrics = self.train_on_batch(batch)
                wandb.log({
                    'Loss': train_metrics['Loss'],
                    'MAPE': train_metrics['MAPE'],
                    'Epoch': epoch
                })
                train_meter.add(train_metrics)
                common.progress_bar(progress=idx / len(self.train_loader),
                                    status=train_meter.return_msg())

            common.progress_bar(progress=1, status=train_meter.return_msg())
            self.logger.write(train_meter.return_msg(), mode='train')
            wandb.log({
                'Learning rate': self.optim.param_groups[0]['lr'],
                'Epoch': epoch
            })

            # Save state
            self.save_state(epoch)

            # Validation
            if epoch % self.config['eval_every'] == 0:
                self.logger.record('Epoch [{:3d}/{}]'.format(
                    epoch, self.config['epochs']), mode='val')
                val_metrics, forecast = self.validate()
                self.compare_predictions(forecast)
                val_meter.add(val_metrics)

                self.logger.record(val_meter.return_msg(), mode='val')
                val_metrics = val_meter.return_metrics()
                wandb.log({
                    'Validation loss': val_metrics['Loss'],
                    'Validation MAPE': val_metrics['MAPE'],
                    'Epoch': epoch
                })

                if val_metrics['Loss'] < self.best_val:
                    self.best_val = val_metrics['Loss']
                    self.save_model()

                if val_metrics['MAPE'] < self.best_mape:
                    self.best_mape = val_metrics['MAPE']

        self.stack_breakdown(run_id)
        self.autoregressive_forecast(run_id)
        print('\n\n[INFO] Training complete!')
        return self.best_val, self.best_mape
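The MAPE tracked above is not defined in these snippets; the standard formula, as a sketch of what `train_on_batch`/`validate` presumably compute (the `eps` guard is an assumption to avoid division by zero):

    import numpy as np

    def mape(y_true, y_pred, eps=1e-8):
        # mean absolute percentage error, reported in percent
        y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
        return 100.0 * np.mean(np.abs(y_true - y_pred) / (np.abs(y_true) + eps))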