Example #1
    def __getitem__(self, index):
        file = self.files[index]

        if file is None:
            y = None
            if isMode(self.mode, 'e1_e2_e3'):
                x = (None, None, None, None)
            elif isMode(self.mode, 'e1_e2'):
                x = (None, None, None)
            elif isMode(self.mode, 'e1'):
                x = (None, None)

            if self.domain == 2:
                x = (x, None)
            return x, y

        x, y = extractData(
            file,
            input_key=self.input_key,
            mode=self.mode,
            domain=self.domain,
            factors=self.factors,
            # e1_matCode=self.e1_materialCode
        )
        return x, y
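
A minimal usage sketch: indexing routes through __getitem__, and a padding entry (a None file appended by the constructor in Example #6) comes back as all-None placeholders sized to the mode, which the collate functions are presumably expected to drop. Assuming a dataset built with an 'e1' mode:

x, y = dataset[0]   # calls __getitem__(0)
if y is None:       # padding entry: x == (None, None) for 'e1' modes,
    pass            # (None, None, None) for 'e1_e2', and so on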
Example #2
def collate(mode='r', domain=0):
    if isMode(mode, 'e1_e2_e3'):
        return collateE1E2E3
    elif isMode(mode, 'e1_e2'):
        if domain == 0:
            return collateE1E2
        elif domain == 2:
            return collateE1E2Dom2
    elif isMode(mode, 'e1'):
        return collateE1
    elif isMode(mode, 'r'):
        return collateR
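
collate is a factory: it returns the mode-appropriate collate function rather than collating anything itself, so callers resolve it once before building a loader. A minimal sketch with illustrative arguments:

collate_fn = collate(mode='e1_e2', domain=2)   # resolves to collateE1E2Dom2

Note that unmatched combinations (e.g. 'e1_e2' with domain=1) fall through and return None, which DataLoader would silently treat as the default collate.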
Example #3
    os.makedirs(os.path.dirname(EXPORT_EXCEL_FILENAME))

# Factors
DATA_FACTOR_ROOT = 'configs/data_factors'
with open(os.path.join(DATA_FACTOR_ROOT, '{}.yml'.format(data_factors))) as fp:
    data_factors = yaml.safe_load(fp)
data_factors = {key: float(val) for key, val in data_factors.items()}

# Extract data_factors
fr = data_factors['r']
fA = data_factors['A']
f_eps = data_factors['eps']
f_lambd = data_factors['lambd']

# Data
if isMode(mode, 'e1_e2'):
    f_name = 'dataGeneration/csv/Au_interpolated_1.csv'
    content = pd.read_csv(f_name)

    PREDICTION_FILE = "PredictionData/r_e1_e2.csv"
    r_e1_e2_data = pd.read_csv(PREDICTION_FILE)
    r_e1_e2_data = {
        key: r_e1_e2_data[key].values
        for key in r_e1_e2_data.keys()
    }
    r_e1_e2_data['r1'] = r_e1_e2_data['r1'] * 1e-9
    r_e1_e2_data['r2'] = r_e1_e2_data['r2'] * 1e-9

    lambd = 1e-9 * content['wl'].values
    e1e2_classes = [
        f'{e1_cls},{e2_cls}' for e1_cls in E1_CLASSES for e2_cls in E2_CLASSES
    ]
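The explicit float cast on the loaded factors matters: PyYAML's YAML 1.1 resolver reads bare scientific notation such as 1e9 as a string (a float literal needs the form 1.0e+9). A small self-contained sketch with hypothetical factor values:

import yaml

sample = """
r: 1e9
A: 1e17
eps: 1e7
lambd: 1e9
"""
loaded = yaml.safe_load(sample)
print(type(loaded['r']))    # <class 'str'> -- not parsed as a float
factors = {k: float(v) for k, v in loaded.items()}
print(factors['r'])         # 1000000000.0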
Example #4
File: utils.py Project: amritsaha607/BTP
def evaluate(model,
             loader,
             mode='default',
             verbose=1,
             rel_err_acc_meters=[1, 5, 10],
             abs_err_acc_meters=[1, 5, 10],
             e1_classes=None,
             domain=0):
    """
        Evaluate model on dataset
    """

    n = len(loader)
    y_tot, y_pred_tot = [], []
    if isMode(mode, 'e1_e2'):
        y_tot_e1e2, y_pred_tot_e1e2 = defaultdict(list), defaultdict(list)
    elif isMode(mode, 'e1'):
        y_tot_e1, y_pred_tot_e1 = defaultdict(list), defaultdict(list)

    # r1_idx, r2_idx = None, None
    # e1r_idx, e1i_idx, e3r_idx, e3i_idx = [None]*4

    r1_idx, r2_idx = 0, 1

    rel_err_acc_r1, rel_err_acc_r2, abs_err_acc_r1, abs_err_acc_r2 = [None] * 4

    if isMode(mode, 'e1_e2'):
        DEFAULT_KEY = 'default'
        e1e2_classes = [
            f'{e1_cls},{e2_cls}' for e1_cls in E1_CLASSES
            for e2_cls in E2_CLASSES
        ]
        rel_err_acc_r1_calculator = {
            DEFAULT_KEY: ErrAcc(mode='rel',
                                err=rel_err_acc_meters,
                                keyPrefix='r1')
        }
        rel_err_acc_r2_calculator = {
            DEFAULT_KEY: ErrAcc(mode='rel',
                                err=rel_err_acc_meters,
                                keyPrefix='r2')
        }
        abs_err_acc_r1_calculator = {
            DEFAULT_KEY: ErrAcc(mode='abs',
                                err=abs_err_acc_meters,
                                keyPrefix='r1')
        }
        abs_err_acc_r2_calculator = {
            DEFAULT_KEY: ErrAcc(mode='abs',
                                err=abs_err_acc_meters,
                                keyPrefix='r2')
        }
        for class_ in e1e2_classes:
            rel_err_acc_r1_calculator[f'{class_}'] = ErrAcc(
                mode='rel', err=rel_err_acc_meters, keyPrefix=f'{class_}_r1')
            rel_err_acc_r2_calculator[f'{class_}'] = ErrAcc(
                mode='rel', err=rel_err_acc_meters, keyPrefix=f'{class_}_r2')
            abs_err_acc_r1_calculator[f'{class_}'] = ErrAcc(
                mode='abs', err=abs_err_acc_meters, keyPrefix=f'{class_}_r1')
            abs_err_acc_r2_calculator[f'{class_}'] = ErrAcc(
                mode='abs', err=abs_err_acc_meters, keyPrefix=f'{class_}_r2')
    elif isMode(mode, 'e1'):
        DEFAULT_KEY = 'default'
        rel_err_acc_r1_calculator = {
            DEFAULT_KEY: ErrAcc(mode='rel',
                                err=rel_err_acc_meters,
                                keyPrefix='r1')
        }
        rel_err_acc_r2_calculator = {
            DEFAULT_KEY: ErrAcc(mode='rel',
                                err=rel_err_acc_meters,
                                keyPrefix='r2')
        }
        abs_err_acc_r1_calculator = {
            DEFAULT_KEY: ErrAcc(mode='abs',
                                err=abs_err_acc_meters,
                                keyPrefix='r1')
        }
        abs_err_acc_r2_calculator = {
            DEFAULT_KEY: ErrAcc(mode='abs',
                                err=abs_err_acc_meters,
                                keyPrefix='r2')
        }
        for class_ in e1_classes:
            rel_err_acc_r1_calculator[f'{class_}'] = ErrAcc(
                mode='rel', err=rel_err_acc_meters, keyPrefix=f'{class_}_r1')
            rel_err_acc_r2_calculator[f'{class_}'] = ErrAcc(
                mode='rel', err=rel_err_acc_meters, keyPrefix=f'{class_}_r2')
            abs_err_acc_r1_calculator[f'{class_}'] = ErrAcc(
                mode='abs', err=abs_err_acc_meters, keyPrefix=f'{class_}_r1')
            abs_err_acc_r2_calculator[f'{class_}'] = ErrAcc(
                mode='abs', err=abs_err_acc_meters, keyPrefix=f'{class_}_r2')
    else:
        rel_err_acc_r1_calculator = ErrAcc(mode='rel',
                                           err=rel_err_acc_meters,
                                           keyPrefix='r1')
        rel_err_acc_r2_calculator = ErrAcc(mode='rel',
                                           err=rel_err_acc_meters,
                                           keyPrefix='r2')
        abs_err_acc_r1_calculator = ErrAcc(mode='abs',
                                           err=abs_err_acc_meters,
                                           keyPrefix='r1',
                                           data_factor=100)
        abs_err_acc_r2_calculator = ErrAcc(mode='abs',
                                           err=abs_err_acc_meters,
                                           keyPrefix='r2',
                                           data_factor=100)

    model.eval()

    for batch_idx, (x, y) in enumerate(loader):

        # y = getLabel(y, mode=mode)

        if isMode(mode, 'e1_e2'):
            x, x_e1, x_e2 = x
        elif isMode(mode, 'e1'):
            x, x_e1 = x

        if torch.cuda.is_available():
            x = x.cuda()
            y = y.cuda()

        if isMode(mode, 'e1_e2'):
            y_pred = model(x, x_e1, x_e2)
        elif isMode(mode, 'e1'):
            y_pred = model(x, x_e1)
        else:
            y_pred = model(x)

        y = y.detach().cpu()
        y_pred = y_pred.detach().cpu()
        y_pred = transform_domain(y_pred, domain=domain, reverse_=True)

        # Calculate metric on the fly
        if isMode(mode, 'e1_e2'):
            rel_err_acc_r1_calculator['default'].feedData(
                y[:, r1_idx].numpy(), y_pred[:, r1_idx].numpy())
            rel_err_acc_r2_calculator['default'].feedData(
                y[:, r2_idx].numpy(), y_pred[:, r2_idx].numpy())
            abs_err_acc_r1_calculator['default'].feedData(
                y[:, r1_idx].numpy(), y_pred[:, r1_idx].numpy())
            abs_err_acc_r2_calculator['default'].feedData(
                y[:, r2_idx].numpy(), y_pred[:, r2_idx].numpy())

            rel_err_acc_r1_calculator[f'{x_e1},{x_e2}'].feedData(
                y[:, r1_idx].numpy(), y_pred[:, r1_idx].numpy())
            rel_err_acc_r2_calculator[f'{x_e1},{x_e2}'].feedData(
                y[:, r2_idx].numpy(), y_pred[:, r2_idx].numpy())
            abs_err_acc_r1_calculator[f'{x_e1},{x_e2}'].feedData(
                y[:, r1_idx].numpy(), y_pred[:, r1_idx].numpy())
            abs_err_acc_r2_calculator[f'{x_e1},{x_e2}'].feedData(
                y[:, r2_idx].numpy(), y_pred[:, r2_idx].numpy())
        elif isMode(mode, 'e1'):
            rel_err_acc_r1_calculator['default'].feedData(
                y[:, r1_idx].numpy(), y_pred[:, r1_idx].numpy())
            rel_err_acc_r2_calculator['default'].feedData(
                y[:, r2_idx].numpy(), y_pred[:, r2_idx].numpy())
            abs_err_acc_r1_calculator['default'].feedData(
                y[:, r1_idx].numpy(), y_pred[:, r1_idx].numpy())
            abs_err_acc_r2_calculator['default'].feedData(
                y[:, r2_idx].numpy(), y_pred[:, r2_idx].numpy())

            rel_err_acc_r1_calculator[f'{x_e1}'].feedData(
                y[:, r1_idx].numpy(), y_pred[:, r1_idx].numpy())
            rel_err_acc_r2_calculator[f'{x_e1}'].feedData(
                y[:, r2_idx].numpy(), y_pred[:, r2_idx].numpy())
            abs_err_acc_r1_calculator[f'{x_e1}'].feedData(
                y[:, r1_idx].numpy(), y_pred[:, r1_idx].numpy())
            abs_err_acc_r2_calculator[f'{x_e1}'].feedData(
                y[:, r2_idx].numpy(), y_pred[:, r2_idx].numpy())
        else:
            rel_err_acc_r1_calculator.feedData(y[:, r1_idx].numpy(),
                                               y_pred[:, r1_idx].numpy())
            rel_err_acc_r2_calculator.feedData(y[:, r2_idx].numpy(),
                                               y_pred[:, r2_idx].numpy())
            abs_err_acc_r1_calculator.feedData(y[:, r1_idx].numpy(),
                                               y_pred[:, r1_idx].numpy())
            abs_err_acc_r2_calculator.feedData(y[:, r2_idx].numpy(),
                                               y_pred[:, r2_idx].numpy())

        # Keep track of predictions vs labels
        y_tot.append(y.numpy())
        y_pred_tot.append(y_pred.numpy())
        if isMode(mode, 'e1_e2'):
            y_tot_e1e2[f'{x_e1},{x_e2}'].append(y.numpy())
            y_pred_tot_e1e2[f'{x_e1},{x_e2}'].append(y_pred.numpy())
        elif isMode(mode, 'e1'):
            y_tot_e1[x_e1].append(y.numpy())
            y_pred_tot_e1[x_e1].append(y_pred.numpy())

        if verbose:
            n_arrow = 50 * (batch_idx + 1) // n
            progress = "Evaluate - [{}>{}] ({}/{})".format(
                "=" * n_arrow, "-" * (50 - n_arrow), (batch_idx + 1), n)
            print(progress, end='\r')

    print()

    # Process total data [labels & predictions]
    y_tot, y_pred_tot = np.concatenate(y_tot), np.concatenate(y_pred_tot)
    if isMode(mode, 'e1_e2'):
        for class_ in e1e2_classes:
            y_tot_e1e2[class_] = np.concatenate(y_tot_e1e2[class_])
            y_pred_tot_e1e2[class_] = np.concatenate(y_pred_tot_e1e2[class_])
    elif isMode(mode, 'e1'):
        for class_ in e1_classes:
            y_tot_e1[class_] = np.concatenate(y_tot_e1[class_])
            y_pred_tot_e1[class_] = np.concatenate(y_pred_tot_e1[class_])

    # Format into r1 & r2
    r1s = [y_tot[:, 0], y_pred_tot[:, 0]]
    r2s = [y_tot[:, 1], y_pred_tot[:, 1]]
    if isMode(mode, 'e1_e2'):
        r1s_e1e2, r2s_e1e2 = {}, {}
        for class_ in e1e2_classes:
            r1s_e1e2[class_] = [
                y_tot_e1e2[class_][:, 0], y_pred_tot_e1e2[class_][:, 0]
            ]
            r2s_e1e2[class_] = [
                y_tot_e1e2[class_][:, 1], y_pred_tot_e1e2[class_][:, 1]
            ]
    elif isMode(mode, 'e1'):
        r1s_e1, r2s_e1 = {}, {}
        for class_ in e1_classes:
            r1s_e1[class_] = [
                y_tot_e1[class_][:, 0], y_pred_tot_e1[class_][:, 0]
            ]
            r2s_e1[class_] = [
                y_tot_e1[class_][:, 1], y_pred_tot_e1[class_][:, 1]
            ]

    # Prepare loggs to return
    loggs = []

    # Append predictions vs labels
    n = len(r1s[0])
    for i in range(n):
        logg = {
            'r1': r1s[0][i],
            'r1_pred': r1s[1][i],
            'r2': r2s[0][i],
            'r2_pred': r2s[1][i]
        }
        loggs.append(logg)
    if isMode(mode, 'e1_e2'):
        for class_ in e1e2_classes:
            n = len(r1s_e1e2[class_][0])
            for i in range(n):
                logg = {
                    f'r1_{class_}': r1s_e1e2[class_][0][i],
                    f'r1_pred_{class_}': r1s_e1e2[class_][1][i],
                    f'r2_{class_}': r2s_e1e2[class_][0][i],
                    f'r2_pred_{class_}': r2s_e1e2[class_][1][i],
                }
                loggs.append(logg)
    elif isMode(mode, 'e1'):
        for class_ in e1_classes:
            n = len(r1s_e1[class_][0])
            for i in range(n):
                logg = {
                    f'r1_{class_}': r1s_e1[class_][0][i],
                    f'r1_pred_{class_}': r1s_e1[class_][1][i],
                    f'r2_{class_}': r2s_e1[class_][0][i],
                    f'r2_pred_{class_}': r2s_e1[class_][1][i],
                }
                loggs.append(logg)

    # Append metric values into loggs
    err_acc = {}

    if isMode(mode, 'e1_e2'):
        for class_ in e1e2_classes + [DEFAULT_KEY]:
            rel_err_acc_r1 = rel_err_acc_r1_calculator[f'{class_}'].getAcc()
            rel_err_acc_r2 = rel_err_acc_r2_calculator[f'{class_}'].getAcc()
            abs_err_acc_r1 = abs_err_acc_r1_calculator[f'{class_}'].getAcc()
            abs_err_acc_r2 = abs_err_acc_r2_calculator[f'{class_}'].getAcc()

            err_acc.update(rel_err_acc_r1)
            err_acc.update(rel_err_acc_r2)
            err_acc.update(abs_err_acc_r1)
            err_acc.update(abs_err_acc_r2)
    elif isMode(mode, 'e1'):
        for class_ in e1_classes + [DEFAULT_KEY]:
            rel_err_acc_r1 = rel_err_acc_r1_calculator[f'{class_}'].getAcc()
            rel_err_acc_r2 = rel_err_acc_r2_calculator[f'{class_}'].getAcc()
            abs_err_acc_r1 = abs_err_acc_r1_calculator[f'{class_}'].getAcc()
            abs_err_acc_r2 = abs_err_acc_r2_calculator[f'{class_}'].getAcc()

            err_acc.update(rel_err_acc_r1)
            err_acc.update(rel_err_acc_r2)
            err_acc.update(abs_err_acc_r1)
            err_acc.update(abs_err_acc_r2)
    else:
        rel_err_acc_r1 = rel_err_acc_r1_calculator.getAcc()
        rel_err_acc_r2 = rel_err_acc_r2_calculator.getAcc()
        abs_err_acc_r1 = abs_err_acc_r1_calculator.getAcc()
        abs_err_acc_r2 = abs_err_acc_r2_calculator.getAcc()

        err_acc.update(rel_err_acc_r1)
        err_acc.update(rel_err_acc_r2)
        err_acc.update(abs_err_acc_r1)
        err_acc.update(abs_err_acc_r2)

    loggs.append(err_acc)

    return loggs
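
A minimal invocation sketch; model, val_loader and E1_CLASSES are assumed to be constructed elsewhere and the mode string is illustrative. The return value is a list of per-sample dicts with one aggregated metric dict appended at the end:

loggs = evaluate(model, val_loader, mode='r_e1',
                 e1_classes=E1_CLASSES, domain=0)
metrics = loggs[-1]       # err_acc dict: rel/abs error accuracies for r1, r2
per_sample = loggs[:-1]   # rows of 'r1', 'r1_pred', 'r2', 'r2_pred' values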
Example #5
def generateHybrid(model_id,
                   mode='r_e1',
                   domain=0,
                   version='v0',
                   in_dim=1761,
                   out_dim=2,
                   root='checkpoints'):

    # if isMode(mode, 'e1_e2_e3'):
    #     dataName = 'E1E2E3Data'
    #     model = E1E2E3Model()
    if isMode(mode, 'e1_e2'):
        dataName = 'E1E2Data'
        classes = [
            f'{e1_cls}_{e2_cls}' for e1_cls in E1_CLASSES
            for e2_cls in E2_CLASSES
        ]
        model = E1E2Model(E1_CLASSES,
                          E2_CLASSES,
                          model_id,
                          input_dim=in_dim,
                          out_dim=out_dim)
    elif isMode(mode, 'e1'):
        dataName = 'E1Data'
        classes = E1_CLASSES
        model = E1Model(E1_CLASSES,
                        model_id,
                        input_dim=in_dim,
                        out_dim=out_dim)

    ckpt_dir = os.path.join(root, f'domain_{domain}', mode, dataName,
                            str(model_id), version)

    print(f"ckpt dir : {ckpt_dir}")
    if isMode(mode, 'e1_e2'):

        if not os.path.exists(os.path.join(ckpt_dir, 'temp_parts')):
            os.makedirs(os.path.join(ckpt_dir, 'temp_parts'))

        for cls_ in classes:
            # Initialize model
            temp_model = E1E2Model(E1_CLASSES,
                                   E2_CLASSES,
                                   model_id,
                                   input_dim=in_dim,
                                   out_dim=out_dim)

            # Find checkpoint
            files_rx = os.path.join(ckpt_dir, f'e1e2_best_{cls_}_*.pth')
            files = glob.glob(files_rx)
            if len(files) == 0:
                raise ValueError(
                    f"No matching checkpoint found with glob pattern {files_rx}")
            ckpt = files[0]

            # Load checkpoint
            temp_model.load_state_dict(
                torch.load(ckpt, map_location=torch.device('cpu')))

            # Save only class branch
            print(f"Saving {cls_} branch")
            e1_cls, e2_cls = cls_.split('_')
            torch.save(temp_model.model[e1_cls][e2_cls].state_dict(),
                       os.path.join(ckpt_dir, 'temp_parts', f'{cls_}.pth'))

        for cls_ in classes:
            e1_cls, e2_cls = cls_.split('_')
            ckpt_name = os.path.join(ckpt_dir, 'temp_parts', f'{cls_}.pth')
            model.model[e1_cls][e2_cls].load_state_dict(
                torch.load(ckpt_name, map_location=torch.device('cpu')))

        torch.save(model.state_dict(), os.path.join(ckpt_dir, 'hybrid.pth'))

        shutil.rmtree(os.path.join(ckpt_dir, 'temp_parts'))

    elif isMode(mode, 'e1'):

        if not os.path.exists(os.path.join(ckpt_dir, 'temp_parts')):
            os.makedirs(os.path.join(ckpt_dir, 'temp_parts'))

        for cls_ in classes:
            # Initialize Model
            temp_model = E1Model(E1_CLASSES,
                                 model_id,
                                 input_dim=in_dim,
                                 out_dim=out_dim)

            # Find checkpoint
            files_rx = os.path.join(ckpt_dir, f'e1_best_{cls_}_*.pth')
            files = glob.glob(files_rx)
            if len(files) == 0:
                raise ValueError(
                    f"No matching checkpoint found with glob pattern {files_rx}")
            ckpt = files[0]

            # Load checkpoint
            temp_model.load_state_dict(
                torch.load(ckpt, map_location=torch.device('cpu')))

            # Save only class branch
            print(f"Saving {cls_} branch")
            torch.save(temp_model.model[cls_].state_dict(),
                       os.path.join(ckpt_dir, 'temp_parts', f'{cls_}.pth'))

        for cls_ in classes:
            ckpt_name = os.path.join(ckpt_dir, 'temp_parts', f'{cls_}.pth')
            model.model[cls_].load_state_dict(
                torch.load(ckpt_name, map_location=torch.device('cpu')))

        torch.save(model.state_dict(), os.path.join(ckpt_dir, 'hybrid.pth'))

        shutil.rmtree(os.path.join(ckpt_dir, 'temp_parts'))
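
Called as below (arguments illustrative; model_id/version must match an existing checkpoint tree), the function stitches each class's best branch into a single checkpoint and removes the temporary parts:

generateHybrid(model_id=3, mode='r_e1', domain=0, version='v0',
               in_dim=1761, out_dim=2, root='checkpoints')
# -> checkpoints/domain_0/r_e1/E1Data/3/v0/hybrid.pth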
Example #6
    def __init__(self,
                 root='dataGeneration/data/',
                 formats=['.csv'],
                 factors=None,
                 input_key='A_tot',
                 mode='r',
                 domain=0,
                 shuffle=True,
                 batch_size=None):

        if isMode(mode, 'e1') and batch_size is None:
            raise AssertionError(
                "Please provide batch_size for mode {}".format(mode))

        super(AreaDataset, self).__init__()

        self.files = []
        self.factors = factors
        self.input_key = input_key
        self.mode = mode
        self.domain = domain

        if isMode(self.mode, 'e1_e2_e3'):
            for format_ in formats:
                for e1_mat in os.listdir(root):
                    e1_root = os.path.join(root, e1_mat)
                    for e2_mat in os.listdir(e1_root):
                        e2_root = os.path.join(e1_root, e2_mat)
                        for e3_mat in os.listdir(e2_root):
                            e3_root = os.path.join(e2_root, e3_mat)
                            files = glob.glob(
                                os.path.join(e3_root, f"*{format_}"))
                            # Shuffle before extending, otherwise the
                            # in-place shuffle never reaches self.files
                            if shuffle:
                                random.shuffle(files)
                            self.files += files

                            if len(files) % batch_size != 0:
                                self.files += [
                                    None
                                ] * int(batch_size - len(files) % batch_size)

        elif isMode(self.mode, 'e1_e2'):
            for format_ in formats:
                for e1_mat in os.listdir(root):
                    e1_root = os.path.join(root, e1_mat)
                    for e2_mat in os.listdir(e1_root):
                        e2_root = os.path.join(e1_root, e2_mat)
                        files = glob.glob(os.path.join(e2_root, f"*{format_}"))
                        # Shuffle before extending, otherwise the in-place
                        # shuffle never reaches self.files
                        if shuffle:
                            random.shuffle(files)
                        self.files += files

                        if len(files) % batch_size != 0:
                            self.files += [None] * int(batch_size -
                                                       len(files) % batch_size)

        else:
            # Shuffling mode changed: files are shuffled within each material,
            # but the materials themselves stay in directory order
            for format_ in formats:
                for material in os.listdir(root):
                    files = glob.glob(
                        os.path.join(root, material, '*{}'.format(format_)))
                    if shuffle:
                        random.shuffle(files)
                    self.files += files

                    # For e1 data, pad with extra (None) files up to a multiple
                    # of batch_size so that samples from different materials
                    # never share a batch
                    if isMode(self.mode,
                              'e1') and (len(files) % batch_size != 0):
                        self.files += [None] * int(batch_size -
                                                   len(files) % batch_size)

        self.setLambda()
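
A hypothetical construction; the directory layout must match root/<material>/<file>.csv for plain modes, and batch_size is mandatory for any 'e1' mode because of the padding above:

dataset = AreaDataset(
    root='dataGeneration/data/',
    formats=['.csv'],
    factors={'r': 1e9, 'eps': 1e7, 'lambd': 1e9, 'A': 1e17},
    input_key='A_tot',
    mode='r_e1',
    domain=0,
    shuffle=True,
    batch_size=32,
)
# len(dataset.files) is a multiple of 32: each material's block is padded
# with None entries so no batch mixes materials.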
Example #7
File: utils.py Project: amritsaha607/BTP
def extractData(filename,
                input_key='A_tot',
                mode='r',
                domain=0,
                e1_matCode=None,
                factors={'r': 1e9, 'eps': 1e7, 'lambd': 1e9, 'A': 1e17}):
    '''
        Extracts data from a csv file for training
        Args:
            filename    : filename to extract data from
                        [supported formats => '.csv']
            input_key   : column to read as the input signal
            mode        : mode string, checked via isMode
            domain      : domain 2 additionally attaches the eps columns
            factors     : scale factors applied to r, eps, lambd and A
    '''

    if not factors:
        f_r, f_eps, f_lambd, f_A = 1, 1, 1, 1
    else:
        f_r, f_eps, f_lambd, f_A = factors['r'], factors['eps'], factors['lambd'], factors['A']
    
    df = pd.read_csv(filename)

    # Input contains two columns, wavelength and area
    # x = np.c_[df['lambd'], df['A_tot']]

    # Input data is a list of combined wavelength and area

    # Sampled values of cross section at specified lambd interval
    x = f_A*df[input_key].values

    # First all wavelength data followed by area data
    # x = np.concatenate([f_lambd*df['lambd'].values, f_A*df['A_tot'].values], axis=0)

    # Output is radii values only
    y = np.array([
        f_r*df['r1'][0],
        f_r*df['r2'][0],
    ])

    if isMode(mode, 'e1_e2_e3'):
        # Extract e1_mat, e2_mat & e3_mat from filename
        info = filename.split('/')
        e1_mat, e2_mat, e3_mat = info[-4], info[-3], info[-2]
        x = (x, e1_mat, e2_mat, e3_mat)

    elif isMode(mode, 'e1_e2'):
        # Extract e1_mat & e2_mat from filename
        info = filename.split('/')
        e1_mat, e2_mat = info[-3], info[-2]
        x = (x, e1_mat, e2_mat)

    elif isMode(mode, 'e1'):
        mat = filename.split('/')[-2] # material name
        x = (x, mat) # pass the material name along with the input
        # y_e1 = oneHot(e1_matCode[mat], len(e1_matCode.keys()))


    if domain == 2:
        eps = {
            'e1': df['eps_1'].values,
            'e2': df['eps_2'].values,
            'e3': df['eps_3'].values,
        }
        x = (x, eps)

    return x, y
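
A hypothetical call; the material names are recovered purely from the directory depth, so the path must match the chosen mode (two material levels for an 'e1_e2' mode, assuming isMode matches the substring):

x, y = extractData(
    'dataGeneration/data/Au/Ag/sample_000.csv',   # illustrative path
    input_key='A_tot',
    mode='r_e1_e2',
    domain=0,
)
area, e1_mat, e2_mat = x    # e1_mat == 'Au', e2_mat == 'Ag'
r1, r2 = y                  # first-row radii, scaled by factors['r']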
Example #8
    input_key=input_key,
    mode=mode,
    shuffle=True,
    batch_size=batch_size,
)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=8,
    collate_fn=collate,
    drop_last=False,
)

# Samples
f = None
if isMode(mode, 'e1_e2'):
    f = glob.glob(os.path.join(data_root, '*', '*', '*.csv'))[0]
else:
    f = glob.glob(os.path.join(data_root, '*', '*.csv'))[0]
n_samples = pd.read_csv(f).values.shape[0]

# Model
# if mode=='default':
#     model_out_dim = 6+2*n_samples
# elif mode=='r':
#     model_out_dim = 2
# elif mode=="eps_sm":
#     model_out_dim = 4
# elif mode=='eps':
#     model_out_dim = 4+2*n_samples
# else:
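
The snippet opens mid-call (presumably an AreaDataset construction, cf. Example #6). Note that collate_fn=collate only works if collate is already a resolved collate function; if it were the factory from Example #2, it would need resolving first, a sketch:

from torch.utils.data import DataLoader

collate_fn = collate(mode=mode, domain=domain)   # resolve the factory
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=8,
    collate_fn=collate_fn,
    drop_last=False,
)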
Example #9
File: train.py Project: amritsaha607/BTP
def run():

    # Optimizer
    if optimizer_ == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            eps=adam_eps,
            amsgrad=adam_amsgrad
        )

    # Scheduler (optional)
    scheduler = None
    if 'scheduler' in configs:
        sch_factor = configs['scheduler']
        lr_lambda = lambda epoch: sch_factor**epoch
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # config = wandb.config
    loggs = []
    config = {}  # stays empty on resumed runs; filled below for fresh runs

    if not cont:
        config = {
            'version': version,
            'domain': domain,
            'model_ID': model_ID,
            'batch_size': batch_size,
            'n_epoch': n_epoch,
            'train_root': train_root,
            'val_root': val_root,
            'data_factors': data_factors,
            'optimizer': optimizer_,
            'lr': learning_rate,
            'weight_decay': weight_decay,
            'adam_eps': adam_eps,
            'amsgrad': adam_amsgrad,
            'CHECKPOINT_DIR': ckpt_dir,
            'scheduler': sch_factor if scheduler is not None else None,
            'cuda': torch.cuda.is_available(),
            'log_interval': 1,
        }
    
    # Initialize wandb
    run_name = "train_{}_{}_dom{}".format(version, mode, domain)
    if log:
        if args.cont is not None:
            wandb.init(id=args.wid, name=run_name, 
                project=WANDB_PROJECT_NAME, dir=WANDB_PROJECT_DIR, resume=True)
        else:
            wandb.init(name=run_name, 
                project=WANDB_PROJECT_NAME, dir=WANDB_PROJECT_DIR,
                config=config)

        wandb.watch(model, log='all')


    BEST_LOSS = float('inf')

    if cont:
        BEST_LOSS = float(args.BEST_VAL_LOSS) if args.BEST_VAL_LOSS is not None else float('inf')
        if scheduler:
            print("Setting up scheduler to continuing state...\n")
            for epoch in range(1, cont + 1):
                if epoch % 10 == 0:
                    scheduler.step()

    topups = []
    # topups = ['loss_split_re']

    # Train & Validate over multiple epochs
    start_epoch = cont+1 if cont is not None else 1
    for epoch in range(start_epoch, n_epoch+1):

        print("Epoch {}".format(epoch))

        logg = {}

        logg_train = train(
            epoch,
            train_loader,
            optimizer,
            metrics=[],
            topups=topups,
            verbose=verbose
        )
        logg_val = validate(
            epoch,
            val_loader,
            metrics=[],
            topups=topups,
            verbose=verbose
        )

        logg.update(logg_train)
        logg.update(logg_val)

        if scheduler and epoch % 10 == 0:
            scheduler.step()
            print("\nepoch {}, lr : {}\n".format(
                epoch, [param_group['lr'] for param_group in optimizer.param_groups]))

        loggs.append(logg)
        if log:
            wandb.log(logg, step=epoch)

        if save:
            if logg['val_loss'] < BEST_LOSS:
                BEST_LOSS = logg['val_loss']
                os.system('rm {}'.format(os.path.join(ckpt_dir, 'best_*.pth')))
                torch.save(model.state_dict(), os.path.join(ckpt_dir, 'best_{}.pth'.format(epoch)))
            # if epoch==n_epoch:
            os.system('rm {}'.format(os.path.join(ckpt_dir, 'latest_*.pth')))
            torch.save(model.state_dict(), os.path.join(ckpt_dir, 'latest_{}.pth'.format(epoch)))

            if isMode(mode, 'e1_e2_e3'):
                for e1_cls in E1_CLASSES:
                    for e2_cls in E2_CLASSES:
                        for e3_cls in E3_CLASSES:
                            cls_ = f'{e1_cls},{e2_cls},{e3_cls}'
                            if logg[f"val_loss_{cls_}"] < E1E2E3_BEST_LOSSES[cls_]:
                                E1E2E3_BEST_LOSSES[cls_] = logg[f"val_loss_{cls_}"]
                                os.system('rm {}'.format(os.path.join(ckpt_dir, f'e1e2_best_{e1_cls}_{e2_cls}_{e3_cls}_*')))
                                torch.save(model.state_dict(), os.path.join(ckpt_dir, f'e1e2_best_{e1_cls}_{e2_cls}_{e3_cls}_{epoch}.pth'))

            elif isMode(mode, 'e1_e2'):
                for e1_cls in E1_CLASSES:
                    for e2_cls in E2_CLASSES:
                        if logg[f"val_loss_{e1_cls},{e2_cls}"] < E1E2_BEST_LOSSES[f"{e1_cls},{e2_cls}"]:
                            E1E2_BEST_LOSSES[f"{e1_cls},{e2_cls}"] = logg[f"val_loss_{e1_cls},{e2_cls}"]
                            os.system('rm {}'.format(os.path.join(ckpt_dir, f'e1e2_best_{e1_cls}_{e2_cls}_*')))
                            torch.save(model.state_dict(), os.path.join(ckpt_dir, f'e1e2_best_{e1_cls}_{e2_cls}_{epoch}.pth'))

            elif isMode(mode, 'e1'):
                for e1_cls in E1_CLASSES:
                    if logg[f"val_loss_{e1_cls}"] < E1_BEST_LOSSES[e1_cls]:
                        E1_BEST_LOSSES[e1_cls] = logg[f"val_loss_{e1_cls}"]
                        os.system('rm {}'.format(os.path.join(ckpt_dir, f'e1_best_{e1_cls}_*')))
                        torch.save(model.state_dict(), os.path.join(ckpt_dir, f'e1_best_{e1_cls}_{epoch}.pth'))

    # Write the loggs to pickle
    pickle_name = os.path.join('cache', 'train',
                               f'dom_{domain}', mode, version.split('_')[0],
                               str(model_ID), '{}.pkl'.format(version.split('_')[1]))
    print(f"\nGenerating pickle file at : {pickle_name}")
    if not os.path.exists(os.path.dirname(pickle_name)):
        os.makedirs(os.path.dirname(pickle_name))
    with open(pickle_name, 'wb') as f:
        pickle.dump([dict(config), loggs], f)
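
Since scheduler.step() is called only every 10th epoch, the learning rate after k scheduler steps is lr * sch_factor**k; a small self-contained sketch with illustrative numbers:

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-3)
sch_factor = 0.9
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda epoch: sch_factor ** epoch)

for _ in range(3):                        # one step per 10 training epochs
    scheduler.step()
print(optimizer.param_groups[0]['lr'])    # 1e-3 * 0.9**3 = 7.29e-4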
Example #10
File: train.py Project: amritsaha607/BTP
def validate(epoch, loader, metrics=[],
             verbose=1, topups=['loss_split_re']):

    """
        epoch : Epoch no
        loader : Validation dataloader
        metrics : metrics to log
    """

    if isMode(mode, 'e1_e2_e3'):
        e1e2e3_losses = defaultdict(float) # key => "<e1_cls>,<e2_cls>,<e3_cls>", value => loss
        e1e2e3_loss_counts = defaultdict(float) # key => "<e1_cls>,<e2_cls>,<e3_cls>"
    elif isMode(mode, 'e1_e2'):
        e1e2_losses = defaultdict(float) # key => "<e1_cls>,<e2_cls>", value => loss
        e1e2_loss_counts = defaultdict(float) # key => "<e1_cls>,<e2_cls>"
    elif isMode(mode, 'e1'):
        e1_losses = defaultdict(float)
        e1_loss_counts = defaultdict(float)

    n = len(loader)
    tot_loss, loss_count = 0.0, 0
    if 'loss_split_re' in topups:
        tot_loss_split = None

    model.eval()
    for batch_idx, (x, y) in enumerate(loader):

        # y = getLabel(y, mode=mode)
        
        # For domain 2, separate eps first
        if domain == 2:
            x, eps = x

        # For e1, e2 modes, break x into parts
        if isMode(mode, 'e1_e2_e3'):
            x, x_e1, x_e2, x_e3 = x
        elif isMode(mode, 'e1_e2'):
            x, x_e1, x_e2 = x
        elif isMode(mode, 'e1'):
            x, x_e1 = x

        if torch.cuda.is_available():
            x = x.cuda()
            y = y.cuda()
            if domain == 2:
                eps = {key: val.cuda() for key, val in eps.items()}

        # y is needed only in domain 0 & 1
        if domain != 2:
            y = transform_domain(y, domain=domain)

        if isMode(mode, 'e1_e2_e3'):
            y_pred = model(x, x_e1, x_e2, x_e3)
        elif isMode(mode, 'e1_e2'):
            y_pred = model(x, x_e1, x_e2)
        elif isMode(mode, 'e1'):
            y_pred = model(x, x_e1)
        else:
            y_pred = model(x)

        # For domain 2, x_pred is all you need to get loss
        if domain == 2:

            # Convert to SI units
            r = {
                'r1': y_pred[:, 0] / data_factors['r'],
                'r2': y_pred[:, 1] / data_factors['r'],
            }
            x_pred = getArea(r, eps, val_set.lambd, ret=input_key).T * data_factors['A']

            # Extract peak information (data factor already applied to the cross section)
            lambd_max_pred, A_max_pred = getPeakInfo(x_pred, val_set.lambd)
            lambd_max, A_max = getPeakInfo(x, val_set.lambd)

            # Apply data factors in lambda_max
            lambd_max = lambd_max * data_factors['lambd']
            lambd_max_pred = lambd_max_pred * data_factors['lambd']

            lambd_max, lambd_max_pred = lambd_max.unsqueeze(dim=1), lambd_max_pred.unsqueeze(dim=1)
            A_max, A_max_pred = A_max.unsqueeze(dim=1), A_max_pred.unsqueeze(dim=1)

            loss = criterion(y_pred, y,
                             mode=loss_mode,
                             run='val',
                             weights=1)
            loss += criterion(lambd_max_pred, lambd_max,
                              mode=loss_mode,
                              run='val',
                              weights=1)
            loss += criterion(A_max_pred, A_max,
                              mode='manhattan',
                              run='val',
                              weights=1)

        else:
            loss = criterion(y_pred, y,
                             mode=loss_mode,
                             run='val',
                             weights=loss_weights)

        if 'loss_split_re' in topups:
            loss_split = criterion(
                y_pred, y, 
                mode=loss_split_mode, run='val',
                weights=loss_weights
            )
            tot_loss_split = dictAdd([tot_loss_split, loss_split]) if tot_loss_split else loss_split

        if not math.isnan(loss.item()):
            tot_loss += loss.item()
            loss_count += 1

            if isMode(mode, 'e1_e2_e3'):
                e1e2e3_losses[f"val_loss_{x_e1},{x_e2},{x_e3}"] += loss.item()
                e1e2e3_loss_counts[f"val_loss_{x_e1},{x_e2},{x_e3}"] += 1
            elif isMode(mode, 'e1_e2'):
                e1e2_losses[f"val_loss_{x_e1},{x_e2}"] += loss.item()
                e1e2_loss_counts[f"val_loss_{x_e1},{x_e2}"] += 1
            elif isMode(mode, 'e1'):
                e1_losses[f"val_loss_{x_e1}"] += loss.item()
                e1_loss_counts[f"val_loss_{x_e1}"] += 1

        if verbose:
            n_arrow = 50 * (batch_idx + 1) // n
            progress = "Validation - [{}>{}] ({}/{}) loss : {:.4f}, avg_loss : {:.4f}".format(
                "=" * n_arrow, "-" * (50 - n_arrow), (batch_idx + 1), n,
                loss.item(), tot_loss / max(loss_count, 1)  # guard against all-NaN batches
            )
            print(progress, end='\r')

    print()
    logg = {
        'val_loss': tot_loss/loss_count,
    }

    # Classwise loss of different materials with different e1, e2
    if isMode(mode, 'e1_e2_e3'):
        for key in e1e2e3_losses:
            e1e2e3_losses[key] /= e1e2e3_loss_counts[key]
        logg.update(e1e2e3_losses)
    elif isMode(mode, 'e1_e2'):
        for key in e1e2_losses:
            e1e2_losses[key] /= e1e2_loss_counts[key]
        logg.update(e1e2_losses)
    elif isMode(mode, 'e1'):
        for key in e1_losses:
            e1_losses[key] /= e1_loss_counts[key]
        logg.update(e1_losses)

    if 'loss_split_re' in topups:
        for key in tot_loss_split:
            tot_loss_split[key] /= loss_count
        logg.update(tot_loss_split)

    return logg
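
validate reads mode, domain, model and the loss settings from module-level globals; the returned dict is flat and ready for logging. A usage sketch in that context:

logg = validate(epoch, val_loader, metrics=[], verbose=1, topups=[])
# logg['val_loss']            -> mean loss over non-NaN batches
# logg['val_loss_<e1>,<e2>']  -> per material pair, when mode is 'e1_e2'
wandb.log(logg, step=epoch)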