def find_lr(self, model, device, train_loader, lr_val=1e-8, decay=1e-2):
    """Run an exponential LR range test from ``lr_val`` up to 100 and plot it.

    The sweep uses SGD (``weight_decay=decay``) with cross-entropy loss over
    100 iterations.  ``reset()`` is not called here; the returned finder can be
    used by the caller to inspect ``history`` or to restore the model.
    """
    sgd = optim.SGD(model.parameters(), lr=lr_val, weight_decay=decay)
    finder = LRFinder(model, sgd, nn.CrossEntropyLoss(), device)
    finder.range_test(train_loader, end_lr=100, num_iter=100, step_mode="exp")
    finder.plot()
    return finder
Beispiel #2
0
def get_LR(model, trainloader, optimizer, criterion, device, val_index=75):
    """Run an LR range test and pick a learning rate between two markers.

    The range test sweeps the learning rate up to ``end_lr=1`` over 100
    iterations.  The selected LR is the one recorded at the iteration midway
    (by index) between ``val_index`` and the iteration with the best loss.

    Args:
        model, optimizer, criterion: forwarded to ``LRFinder``.
        trainloader: data loader that drives the range test.
        device: device the range test runs on.
        val_index: history index of the upper marker.  It is clamped to the
            last recorded iteration, because the range test may stop early
            when the loss diverges.

    Returns:
        The selected learning rate.

    Raises:
        ValueError: if the range test recorded no iterations.
    """
    print("########## Tweaked version from fastai ###########")
    # Honour the caller's device instead of a hard-coded "cuda".
    lr_find = LRFinder(model, optimizer, criterion, device=device)
    lr_find.range_test(trainloader, end_lr=1, num_iter=100)
    lr_find.plot()  # to inspect the loss-learning rate graph
    lr_find.reset()

    lrs = lr_find.history['lr']
    losses = lr_find.history['loss']
    if not losses:
        raise ValueError("LR range test recorded no iterations")

    # Last iteration whose recorded loss equals the finder's best loss; fall
    # back to the plain argmin so the index is always defined.
    matches = [i for i, loss in enumerate(losses) if loss == lr_find.best_loss]
    if matches:
        min_val_index = matches[-1]
    else:
        min_val_index = min(range(len(losses)), key=losses.__getitem__)
    print(f"{min_val_index}")

    # A diverging sweep ends early, so the history can be shorter than val_index.
    val_index = min(val_index, len(lrs) - 1)

    lr_find.plot(show_lr=lrs[val_index])
    lr_find.plot(show_lr=lrs[min_val_index])

    mid_val_index = (val_index + min_val_index) // 2
    best_lr = lrs[mid_val_index]
    print(f"LR to be used: {best_lr}")
    return best_lr
Beispiel #3
0
def lr_finder(model, optimizer, criterion, trainloader):
    """Run an exponential LR range test (up to 100) on CUDA, plot it, and reset.

    The model and optimizer are restored to their pre-test state before returning.
    """
    finder = LRFinder(model, optimizer, criterion, device="cuda")
    sweep = dict(end_lr=100, num_iter=100, step_mode="exp")
    finder.range_test(trainloader, **sweep)
    finder.plot()  # loss vs. learning-rate curve
    finder.reset()  # restore model/optimizer to their initial state
Beispiel #4
0
def lr_finder(model, train_loader):
    """Run a 1000-iteration LR range test (up to lr=1) with Adam and plot it.

    Runs on CUDA when available, otherwise on CPU, with cross-entropy loss.
    The model/optimizer are reset before plotting; the recorded history is
    what gets plotted.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    optimizer_ft = optim.Adam(model.parameters(), lr=1e-7)
    finder = LRFinder(model, optimizer_ft, nn.CrossEntropyLoss(), device=device)
    finder.range_test(train_loader, end_lr=1, num_iter=1000)
    finder.reset()
    finder.plot()
    def lr_finder(self, end_lr):
        """Run a 2000-iteration LR range test up to ``end_lr`` and plot it.

        Uses ``self.model``, ``self.optimizer`` and ``self.criterion`` on
        ``cfg.device``.  Afterwards both model and optimizer are reset, so the
        sweep leaves no trace on the trainer's state.

        NOTE(review): the sweep iterates ``self.data_loaders['val']``, so the
        optimizer steps on validation batches — confirm this is intended.
        """
        lr_find = LRFinder(self.model, self.optimizer, self.criterion,
                           cfg.device)
        lr_find.range_test(self.data_loaders['val'],
                           end_lr=end_lr,
                           num_iter=2000)
        lr_find.plot()
        # range_test trains at ever larger learning rates; restore the model
        # and optimizer to their pre-test state so later training is unaffected.
        lr_find.reset()
Beispiel #6
0
def executeLr_finder(model, optimizer, device, trainloader, criterion):
    """Run an exponential LR range test (up to 100), plot it, then reset.

    Args:
        model, optimizer, criterion: forwarded to ``LRFinder``.
        device: device the range test runs on.
        trainloader: data loader that drives the sweep.
    """
    # Find and plot the best LR on the caller's device (not a hard-coded "cuda").
    lr_finder = LRFinder(model, optimizer, criterion, device=device)
    lr_finder.range_test(trainloader,
                         end_lr=100,
                         num_iter=100,
                         step_mode="exp")
    lr_finder.plot()  # to inspect the loss-learning rate graph

    lr_finder.reset()  # to reset the model and optimizer to their initial state
Beispiel #7
0
def lr_finder(net, optimizer, loss_fun, trainloader, testloader):
    """Sweep the LR exponentially from 1e-3 to 0.1, scoring on ``testloader``.

    Plots loss vs. learning rate on a linear LR axis, restores ``net`` and
    ``optimizer`` to their initial state, and returns the finder's ``history``
    (per-iteration ``lr``/``loss`` records).
    """
    finder = LRFinder(net, optimizer, loss_fun, device='cuda')
    finder.range_test(
        trainloader,
        val_loader=testloader,
        start_lr=1e-3,
        end_lr=0.1,
        num_iter=100,
        step_mode='exp',
    )
    finder.plot(log_lr=False)
    # Restore the model and optimizer parameters to their initial state.
    finder.reset()
    history = finder.history
    return history
def get_LR(model, trainloader, optimizer, criterion, device, testloader=None):
    """Pick a learning rate with Leslie Smith's linear LR range test.

    Sweeps the LR linearly up to ``end_lr=1`` over 100 iterations, optionally
    scoring each step on ``testloader``.  The plot is drawn on a linear LR
    axis, and the model/optimizer are reset afterwards.

    Args:
        model, optimizer, criterion: forwarded to ``LRFinder``.
        trainloader: data loader that drives the sweep.
        device: device the range test runs on.
        testloader: optional validation loader used to compute the loss.

    Returns:
        Whatever ``lr_find.plot(log_lr=False)`` returns.  NOTE(review): in the
        public torch-lr-finder package ``plot()`` returns the matplotlib Axes,
        not a learning rate.  This relies on the project's own ``LRFinder``
        returning the best LR — confirm.
    """
    print("########## Leslie Smith's approach ###########")
    # Honour the caller's device instead of a hard-coded "cuda".
    lr_find = LRFinder(model, optimizer, criterion, device=device)
    lr_find.range_test(trainloader,
                       val_loader=testloader,
                       end_lr=1,
                       num_iter=100,
                       step_mode="linear")
    best_lr = lr_find.plot(log_lr=False)
    lr_find.reset()
    return best_lr
Beispiel #9
0
    #iaa.Sometimes(0.1, iaa.Grayscale(alpha=(0.0, 1.0), from_colorspace="RGB", name="grayscale")),
    # iaa.Sometimes(0.2, iaa.AdditiveLaplaceNoise(scale=(0, 0.1*255), per_channel=True, name="gaus-noise")),
    # Color, Contrast, etc.
    iaa.Sometimes(0.2, iaa.Multiply((0.75, 1.25), per_channel=0.1, name="brightness")),
    iaa.Sometimes(0.2, iaa.GammaContrast((0.7, 1.3), per_channel=0.1, name="contrast")),
    iaa.Sometimes(0.2, iaa.AddToHueAndSaturation((-20, 20), name="hue-sat")),
    iaa.Sometimes(0.3, iaa.Add((-20, 20), per_channel=0.5, name="color-jitter")),
])
# Test-time pipeline: only a deterministic resize to (imsize, imsize).
augs_test = iaa.Sequential([
    # Geometric Augs
    # NOTE(review): the 2nd arg is presumably the interpolation (0 -> nearest
    # neighbour); iaa.Scale may be a deprecated alias of iaa.Resize in newer
    # imgaug releases — verify against the installed version.
    iaa.Scale((imsize, imsize), 0),
])


# Training segmentation dataset using the augs_train pipeline.
# input_only presumably lists augmenter `name=`s that must be applied to the
# input image only (not to the label masks) — confirm in AlphaPilotSegmentation.
db_train = AlphaPilotSegmentation(
    input_dir='data/dataset/train/images', label_dir='data/dataset/train/labels',
    transform=augs_train,
    input_only=["gaus-blur", "grayscale", "gaus-noise", "brightness", "contrast", "hue-sat", "color-jitter"],
    return_image_name=False
)


# drop_last=True discards the final incomplete batch.
trainloader = DataLoader(db_train, batch_size=p['trainBatchSize'], shuffle=True, num_workers=32, drop_last=True)


# %matplotlib inline

# LR range test up to lr=1 over 100 iterations; net/optimizer/criterion are
# defined elsewhere in the script.
# NOTE(review): reset() is never called, so `net` keeps the weights produced by
# the sweep — call lr_finder.reset() if training follows.
lr_finder = LRFinder(net, optimizer, criterion, device="cuda")
lr_finder.range_test(trainloader, end_lr=1, num_iter=100)
lr_finder.plot()
# plt.show()
Beispiel #10
0
#                 else tensor.size(self.ch_dim)
#                 for i in range(2)]
# pad = torch.empty(padding_size, dtype=tensor.dtype).fill_(self.fill_value)
# tensor = torch.cat((tensor, pad), dim=self.len_dim)

# import os
# os.environ['MKL_NUM_THREADS'] = '1'
# import numpy as np
# import utils
#
#
# def main():
#     inpt = np.random.randint(-10, 10, 150000)
#
#     for i in range(1000):
#         out = utils.spectrogram(inpt, 256)
#
#
# if __name__ == '__main__':
#     main()

from lr_finder import LRFinder

# Cross-entropy loss with Adam; lr=1e-5 is presumably the sweep's starting LR.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)
lr_finder = LRFinder(model, optimizer, criterion, device="cuda")
# Exponential sweep up to lr=1 over 50 iterations.
lr_finder.range_test(train_loader, end_lr=1, num_iter=50, step_mode="exp")
# NOTE(review): the return value is discarded.  It is only displayed if this is
# the last expression of a notebook cell — assign/print it otherwise.
# reset() is also never called, so `model` keeps the sweep's weights.
lr_finder.get_best_lr()
# lr_finder.plot()
# lr_finder.history
Beispiel #11
0
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Finetuning the convnet: pretrained ResNet-18 with a new 2-way output layer.
model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

model = model.to(device)

criterion = nn.CrossEntropyLoss()

# Select a small learning rate for the start
optimizer_ft = optim.SGD(model.parameters(), lr=1e-5, momentum=0.9)
# Run the range test on the same device the model was moved to; a hard-coded
# "cuda" fails on CPU-only machines even though `device` falls back to CPU.
lr_finder = LRFinder(model, optimizer_ft, criterion, device=device)

# Using the train loss
lr_finder.range_test(dataloaders['train'], end_lr=100, num_iter=1000, step_mode='exp')
lr_finder.plot()

# Using the validation loss (reset first so the second sweep starts from the
# initial weights/optimizer state)
lr_finder.reset()
lr_finder.range_test(dataloaders['train'], val_loader=dataloaders['val'], end_lr=100, num_iter=200, step_mode='exp')
lr_finder.plot(skip_end=0)
# Restore the initial model/optimizer state so any later training does not
# start from the weights left behind by the sweep.
lr_finder.reset()
  train_loader = DataLoader(train_ds,batch_size=batch_size, sampler=BalanceClassSampler(labels=train_ds.get_labels(), mode="downsampling"), shuffle=False, num_workers=4)
else:
  train_loader = DataLoader(train_ds,batch_size=batch_size, shuffle=True, num_workers=4)

# Parameter groups: the backbone uses a learning rate 50x lower than meta_fc.
plist = [
    {'params': model.backbone.parameters(), 'lr': learning_rate / 50},
    {'params': model.meta_fc.parameters(), 'lr': learning_rate},
    # {'params': model.metric_classify.parameters(),  'lr': learning_rate},
]

optimizer = optim.Adam(plist, lr=learning_rate)
# lr_reduce_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=patience, verbose=True, threshold=1e-4, threshold_mode='rel', cooldown=0, min_lr=1e-7, eps=1e-08)
# cyclic_scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=learning_rate, max_lr=10*learning_rate, step_size_up=2000, step_size_down=2000, mode='triangular', gamma=1.0, scale_fn=None, scale_mode='cycle', cycle_momentum=False, base_momentum=0.8, max_momentum=0.9, last_epoch=-1)

criterion = criterion_margin_focal_binary_cross_entropy
if load_model:
    # NOTE: torch.load unpickles the checkpoint — only load trusted files.
    tmp = torch.load(os.path.join(model_dir, model_name + '_loss.pth'))
    model.load_state_dict(tmp['model'])
    # optimizer.load_state_dict(tmp['optim'])
    # lr_reduce_scheduler.load_state_dict(tmp['scheduler'])
    # cyclic_scheduler.load_state_dict(tmp['cyclic_scheduler'])
    # amp.load_state_dict(tmp['amp'])
    prev_epoch_num = tmp['epoch']
    best_valid_loss = tmp['best_loss']
    del tmp
    print('Model Loaded!')
# model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
lr_finder = LRFinder(model, optimizer, criterion, device="cuda")
lr_finder.range_test(train_loader, end_lr=100, num_iter=500, accumulation_steps=accum_step)
lr_finder.plot()  # to inspect the loss-learning rate graph
# Restore the model/optimizer to their pre-test state; the sweep up to lr=100
# otherwise leaves the (possibly just loaded) weights overwritten.
lr_finder.reset()
Beispiel #13
0
from lr_finder import LRFinder
from src.model_lib.MultiFTNet import MultiFTNet
from src.model_lib.MiniFASNet import MiniFASNetV1, MiniFASNetV2,MiniFASNetV1SE,MiniFASNetV2SE
from src.utility import get_kernel
from torch.nn import CrossEntropyLoss, MSELoss
from torch import optim
from src.data_io.dataset_loader import get_train_loader,get_eval_loader
from src.default_config import get_default_config, update_config
from train import parse_args
import os
# Expose only physical GPU 3; this must happen before CUDA is first initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
# NOTE(review): get_kernel(80, 60) presumably derives conv6's kernel size from
# the input resolution — confirm in src.utility.
kernel_size = get_kernel(80, 60)
model = MultiFTNet(conv6_kernel = kernel_size)
cls_criterion = CrossEntropyLoss()
FT_criterion = MSELoss()
# NOTE(review): redundant — optim is already imported at the top of this example.
from torch import optim
# optimizer = optim.SGD(model.parameters(),
#                                    lr=0.1,
#                                    weight_decay=5e-4,
#                                    momentum=0.9)
# AdamW with its default learning rate.
optimizer = optim.AdamW(model.parameters())
# NOTE(review): the 4th positional parameter of LRFinder is `device` in
# torch-lr-finder (other examples here pass device= by keyword as well).
# Passing FT_criterion positionally looks like a bug — confirm the local
# lr_finder.LRFinder signature.  Presumably only cls_criterion acts as the loss.
lr_finder = LRFinder(model, optimizer, cls_criterion,FT_criterion)
conf = get_default_config()
args = parse_args()
conf = update_config(args, conf)
trainloader = get_train_loader(conf)
# NOTE(review): val_loader is built but not used by the range test below.
val_loader = get_eval_loader(conf)
# Linear sweep up to lr=1 over 100 iterations, plotted on a linear LR axis,
# then restore the model/optimizer to their pre-test state.
lr_finder.range_test(trainloader, end_lr=1, num_iter=100, step_mode="linear")
lr_finder.plot(log_lr=False)
lr_finder.reset()
Beispiel #14
0
def train_loop(folds, fold):
    """Fine-tune a saved student checkpoint on one CV fold and finish with SWA.

    The model is initialised from
    ``{CFG.model_name}_student_fold{fold}_best_score.pth`` and trained for
    ``CFG.epochs`` epochs with SGD and FocalLoss.  Whenever the validation
    score improves, the weights are written to ``OUTPUT_DIR``.  After the last
    epoch, the SWA-averaged model is evaluated and saved as well.  If the local
    ``find_lr`` switch is turned on, an LR range test runs instead and the
    function returns early.

    Args:
        folds: DataFrame with a ``fold`` column and the ``CFG.target_cols``
            labels.
        fold: identifier of the fold held out for validation.

    Raises:
        ValueError: if ``CFG.scheduler`` names an unsupported scheduler.
    """
    if CFG.device == 'GPU':
        LOGGER.info(f"========== fold: {fold} training ==========")

    # ====================================================
    # loader
    # ====================================================
    trn_idx = folds[folds['fold'] != fold].index
    val_idx = folds[folds['fold'] == fold].index

    train_folds = folds.loc[trn_idx].reset_index(drop=True)
    valid_folds = folds.loc[val_idx].reset_index(drop=True)
    valid_labels = valid_folds[CFG.target_cols].values

    train_dataset = TrainDataset(train_folds,
                                 transform=get_transforms(data='train'))
    valid_dataset = TrainDataset(valid_folds,
                                 transform=get_transforms(data='valid'))

    train_loader = DataLoader(train_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=True,
                              num_workers=CFG.num_workers,
                              pin_memory=True,
                              drop_last=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=CFG.batch_size * 2,
                              shuffle=False,
                              num_workers=CFG.num_workers,
                              pin_memory=True,
                              drop_last=False)

    # ====================================================
    # scheduler
    # ====================================================
    def get_scheduler(optimizer):
        """Build the LR scheduler selected by ``CFG.scheduler``."""
        if CFG.scheduler == 'ReduceLROnPlateau':
            return ReduceLROnPlateau(optimizer,
                                     mode='min',
                                     factor=CFG.factor,
                                     patience=CFG.patience,
                                     verbose=True,
                                     eps=CFG.eps)
        if CFG.scheduler == 'CosineAnnealingLR':
            return CosineAnnealingLR(optimizer,
                                     T_max=CFG.T_max,
                                     eta_min=CFG.min_lr,
                                     last_epoch=-1)
        if CFG.scheduler == 'CosineAnnealingWarmRestarts':
            return CosineAnnealingWarmRestarts(optimizer,
                                               T_0=CFG.T_0,
                                               T_mult=1,
                                               eta_min=CFG.min_lr,
                                               last_epoch=-1)
        # Fail with a clear message instead of an UnboundLocalError.
        raise ValueError(f"Unsupported CFG.scheduler: {CFG.scheduler!r}")

    # ====================================================
    # model & optimizer
    # ====================================================

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = CustomModel(CFG.model_name, pretrained=False)
    model = torch.nn.DataParallel(model)
    model.load_state_dict(
        torch.load(f'{CFG.model_name}_student_fold{fold}_best_score.pth',
                   map_location=torch.device('cpu'))['model'])
    model.to(device)

    criterion = FocalLoss(alpha=1, gamma=6)

    optimizer = SGD(model.parameters(),
                    lr=1e-2,
                    weight_decay=CFG.weight_decay,
                    momentum=0.9)

    find_lr = False
    if find_lr:
        from lr_finder import LRFinder
        lr_finder = LRFinder(model, optimizer, criterion, device=device)
        lr_finder.range_test(train_loader,
                             start_lr=1e-2,
                             end_lr=1e0,
                             num_iter=100,
                             accumulation_steps=1)

        fig_name = f'{CFG.model_name}_lr_finder.png'
        # NOTE(review): in torch-lr-finder the first positional parameter of
        # plot() is skip_start; this assumes the local LRFinder accepts a file
        # name here — confirm.
        lr_finder.plot(fig_name)
        lr_finder.reset()
        return
    scheduler = get_scheduler(optimizer)
    swa_model = torch.optim.swa_utils.AveragedModel(model)
    swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=1e-3)
    swa_start = 9

    # ====================================================
    # loop
    # ====================================================

    best_score = 0.

    for epoch in range(CFG.epochs):

        start_time = time.time()

        # train
        avg_loss = train_fn(train_loader, model, criterion, optimizer, epoch,
                            scheduler, device)

        # eval — must run before the scheduler step: ReduceLROnPlateau consumes
        # this epoch's validation loss (stepping first read an unassigned
        # `avg_val_loss` on epoch 0 and a stale one afterwards).  Validation
        # does not depend on the scheduler/SWA updates, so the scores are
        # unchanged for the cosine schedulers.
        avg_val_loss, preds, _ = valid_fn(valid_loader, model, criterion,
                                          device)

        if epoch > swa_start:
            swa_model.update_parameters(model)
            swa_scheduler.step()
        elif isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(avg_val_loss)
        elif isinstance(scheduler,
                        (CosineAnnealingLR, CosineAnnealingWarmRestarts)):
            scheduler.step()

        # scoring
        score, scores = get_score(valid_labels, preds)

        elapsed = time.time() - start_time

        LOGGER.info(
            f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s'
        )
        LOGGER.info(
            f'Epoch {epoch+1} - Score: {score:.4f}  Scores: {np.round(scores, decimals=4)}'
        )

        if score > best_score:
            best_score = score
            LOGGER.info(
                f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
            torch.save({'model': model.state_dict()}, OUTPUT_DIR +
                       f'{CFG.model_name}_no_hflip_fold{fold}_best_score.pth')

    # Recompute BatchNorm statistics for the averaged weights.  Batches are
    # moved to `device` explicitly rather than relying on DataParallel's input
    # scattering.
    torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
    avg_val_loss, preds, _ = valid_fn(valid_loader, swa_model, criterion,
                                      device)
    score, scores = get_score(valid_labels, preds)
    LOGGER.info(f'Save swa Score: {score:.4f} Model')
    torch.save({'model': swa_model.state_dict()},
               OUTPUT_DIR + f'swa_{CFG.model_name}_fold{fold}_{score:.4f}.pth')

    return
def main(args):
    """Train ``models.BaselineNetRawSignalCnnRnnV1`` as a ``num_bins - 1``-way classifier.

    Steps:
    - Load ``train_info.csv`` and ``train_compressed.npz`` from ``../data``
      (relative to this file).
    - Hold out everything from the start of the second-to-last wave (the last
      two waves) for validation.
    - Optionally run an LR range test (``args.find_lr``).
    - Train with Adam + ReduceLROnPlateau via ``utils.train_clf_model``.

    Args:
        args: namespace providing ``batch_size``, ``find_lr``, ``num_epochs``
            and ``model_name``.
    """
    # load train data into ram
    # data_path = '/mntlong/lanl_comp/data/'
    file_dir = os.path.dirname(__file__)
    data_path = os.path.abspath(os.path.join(file_dir, os.path.pardir, 'data'))
    train_info_path = os.path.join(data_path, 'train_info.csv')
    train_data_path = os.path.join(data_path, 'train_compressed.npz')

    train_info = pd.read_csv(train_info_path, index_col='Unnamed: 0')
    train_info['exp_len'] = train_info['indx_end'] - train_info['indx_start']

    # Open the archive once and close it afterwards.  Indexing an NpzFile reads
    # the array into memory, so the arrays remain valid after the `with`.
    with np.load(train_data_path) as train_data:
        train_signal = train_data['signal']
        train_quaketime = train_data['quake_time']

    # The validation set takes the last 2 waves (parts of the experiment).
    val_start_idx = train_info.iloc[-2, :]['indx_start']

    val_signal = train_signal[val_start_idx:]
    val_quaketime = train_quaketime[val_start_idx:]

    train_signal = train_signal[:val_start_idx]
    train_quaketime = train_quaketime[:val_start_idx]

    # training params
    window_size = 150000
    overlap_size = int(window_size * 0.5)
    num_bins = 17

    model = models.BaselineNetRawSignalCnnRnnV1(out_size=num_bins-1)
    loss_fn = nn.CrossEntropyLoss()  # L1Loss() SmoothL1Loss() MSELoss()

    # logs_path = '/mntlong/scripts/logs/'
    logs_path = os.path.abspath(os.path.join(file_dir, os.path.pardir, 'logs'))
    current_datetime = datetime.today().strftime('%b-%d_%H-%M-%S')
    log_writer_path = os.path.join(logs_path, 'runs',
                                   current_datetime + '_' + args.model_name)

    train_dataset = data.SignalDataset(train_signal, train_quaketime,
                                       num_bins=num_bins,
                                       idxs_wave_end=train_info['indx_end'].values,
                                       window_size=window_size,
                                       overlap_size=overlap_size)
    # NOTE(review): val_signal is re-indexed from 0, but idxs_wave_end still
    # holds absolute indices from train_info — confirm SignalDataset expects that.
    val_dataset = data.SignalDataset(val_signal, val_quaketime,
                                     num_bins=num_bins,
                                     idxs_wave_end=train_info['indx_end'].values,
                                     window_size=window_size,
                                     overlap_size=overlap_size)

    print('wave size:', train_dataset[0][0].size())

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=5,
                              pin_memory=True)
    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=5,
                            pin_memory=True)

    if args.find_lr:
        from lr_finder import LRFinder
        optimizer = optim.Adam(model.parameters(), lr=1e-6)
        lr_find = LRFinder(model, optimizer, loss_fn, device='cuda')
        lr_find.range_test(train_loader, end_lr=1, num_iter=50, step_mode='exp')
        best_lr = lr_find.get_best_lr()
        lr_find.plot()
        lr_find.reset()
        print('best lr found: {:.2e}'.format(best_lr))
    else:
        best_lr = 3e-4

    optimizer = optim.Adam(model.parameters(), lr=best_lr)  # weight_decay=0.1
    lr_sched = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                    factor=0.5,
                                                    patience=3,
                                                    threshold=0.005)
    log_writer = SummaryWriter(log_writer_path)

    utils.train_clf_model(model=model, optimizer=optimizer, lr_scheduler=lr_sched,
                          train_loader=train_loader, val_loader=val_loader,
                          num_epochs=args.num_epochs, model_name=args.model_name,
                          logs_path=logs_path, log_writer=log_writer,
                          loss_fn=loss_fn, num_bins=num_bins)
Beispiel #16
0
def main(args):
    """Train ``models.CPCv1`` on windowed signal data.

    Steps:
    - Load ``train_info.csv`` and ``train_compressed.npz`` from ``../data``
      (relative to this file).
    - Hold out everything from the start of the second-to-last wave (the last
      two waves) for validation.
    - Optionally run an LR range test (``args.find_lr``).
    - Train with Adam + ReduceLROnPlateau via ``utils.train_cpc_model``.

    Args:
        args: namespace providing ``batch_size``, ``find_lr``, ``num_epochs``
            and ``model_name``.
    """
    # load train data into ram
    # data_path = '/mntlong/lanl_comp/data/'
    file_dir = os.path.dirname(__file__)
    data_path = os.path.abspath(os.path.join(file_dir, os.path.pardir, 'data'))
    train_info_path = os.path.join(data_path, 'train_info.csv')
    train_data_path = os.path.join(data_path, 'train_compressed.npz')

    train_info = pd.read_csv(train_info_path, index_col='Unnamed: 0')
    train_info['exp_len'] = train_info['indx_end'] - train_info['indx_start']

    # Open the archive once and close it afterwards.  Indexing an NpzFile reads
    # the array into memory, so the arrays remain valid after the `with`.
    with np.load(train_data_path) as train_data:
        train_signal = train_data['signal']
        train_quaketime = train_data['quake_time']

    # The validation set takes the last 2 waves (parts of the experiment).
    val_start_idx = train_info.iloc[-2, :]['indx_start']

    val_signal = train_signal[val_start_idx:]
    val_quaketime = train_quaketime[val_start_idx:]

    train_signal = train_signal[:val_start_idx]
    train_quaketime = train_quaketime[:val_start_idx]

    # training params
    large_ws = 1500000
    overlap_size = int(large_ws * 0.5)
    small_ws = 150000
    num_bins = 17

    cpc_meta_model = models.CPCv1(out_size=num_bins - 1)

    # logs_path = '/mntlong/scripts/logs/'
    logs_path = os.path.abspath(os.path.join(file_dir, os.path.pardir, 'logs'))
    current_datetime = datetime.today().strftime('%b-%d_%H-%M-%S')
    log_writer_path = os.path.join(logs_path, 'runs',
                                   current_datetime + '_' + args.model_name)

    train_dataset = data.SignalCPCDataset(
        train_signal,
        train_quaketime,
        num_bins=num_bins,
        idxs_wave_end=train_info['indx_end'].values,
        large_ws=large_ws,
        overlap_size=overlap_size,
        small_ws=small_ws)
    # NOTE(review): val_signal is re-indexed from 0, but idxs_wave_end still
    # holds absolute indices from train_info — confirm the dataset expects that.
    val_dataset = data.SignalCPCDataset(
        val_signal,
        val_quaketime,
        num_bins=num_bins,
        idxs_wave_end=train_info['indx_end'].values,
        large_ws=large_ws,
        overlap_size=overlap_size,
        small_ws=small_ws)

    print('x_t size:', train_dataset[0][0].size())

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=5,
                              pin_memory=True)
    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=5,
                            pin_memory=True)

    if args.find_lr:
        from lr_finder import LRFinder
        optimizer = optim.Adam(cpc_meta_model.parameters(), lr=1e-6)
        lr_find = LRFinder(cpc_meta_model,
                           optimizer,
                           criterion=None,
                           is_cpc=True,
                           device='cuda')
        lr_find.range_test(train_loader,
                           end_lr=2,
                           num_iter=75,
                           step_mode='exp')
        best_lr = lr_find.get_best_lr()
        lr_find.plot()
        lr_find.reset()
        print('best lr found: {:.2e}'.format(best_lr))
    else:
        best_lr = 3e-4

    # model_path = os.path.join(logs_path, 'cpc_no_target_head_cont_last_state.pth')
    # cpc_meta_model.load_state_dict(torch.load(model_path)['model_state_dict'])
    # cpc_meta_model.to(torch.device('cuda'))

    optimizer = optim.Adam(cpc_meta_model.parameters(), lr=best_lr)
    # optimizer.load_state_dict(torch.load(model_path)['optimizer_state_dict'])
    lr_sched = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                    factor=0.5,
                                                    patience=3,
                                                    threshold=0.005)

    log_writer = SummaryWriter(log_writer_path)

    utils.train_cpc_model(cpc_meta_model=cpc_meta_model,
                          optimizer=optimizer,
                          num_bins=num_bins,
                          lr_scheduler=lr_sched,
                          train_loader=train_loader,
                          val_loader=val_loader,
                          num_epochs=args.num_epochs,
                          model_name=args.model_name,
                          logs_path=logs_path,
                          log_writer=log_writer)
def main(args):
    """Train ``models.resnet34`` with BCE-with-logits loss and save checkpoints.

    Steps:
    - Create a fresh experiment directory under ``args.outpath``.
    - Log an LR range-test plot (loss vs. LR) to TensorBoard.
    - Train for ``args.epochs`` epochs with Adam + cosine annealing.
    - After every epoch write ``last.pth``.
    - At the end write ``best.pth``: the epoch with the highest ``lwlrap``,
      falling back to epoch 0 when no epoch exceeds 0.

    Args:
        args: namespace providing ``outpath``, ``datadir``, ``batch_size`` and
            ``epochs``.
    """
    np.random.seed(432)
    torch.random.manual_seed(432)
    # Create the output root if needed; genuine failures (e.g. permissions) are
    # no longer swallowed by a blanket `except OSError`.
    os.makedirs(args.outpath, exist_ok=True)
    experiment_path = utils.get_new_model_path(args.outpath)
    print(experiment_path)

    train_writer = SummaryWriter(os.path.join(experiment_path, 'train_logs'))
    val_writer = SummaryWriter(os.path.join(experiment_path, 'val_logs'))
    trainer = train.Trainer(train_writer, val_writer)

    # todo: add config
    train_transform = data.build_preprocessing()
    eval_transform = data.build_preprocessing()

    trainds, evalds = data.build_dataset(args.datadir, None)
    trainds.transform = train_transform
    evalds.transform = eval_transform

    model = models.resnet34()
    # Optimizer used only for the LR range test; a fresh one is built for training.
    opt = torch.optim.Adam(model.parameters(), lr=1e-8)

    trainloader = DataLoader(trainds,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=8,
                             pin_memory=True)
    evalloader = DataLoader(evalds,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=16,
                            pin_memory=True)

    # find lr, fast.ai style
    criterion = torch.nn.BCEWithLogitsLoss()
    lr_finder = LRFinder(model, opt, criterion, device="cuda")
    lr_finder.range_test(trainloader,
                         end_lr=100,
                         num_iter=100,
                         step_mode="exp")

    # Plot loss vs. LR, trimming the first/last points, and log it to TensorBoard.
    skip_start = 6
    skip_end = 3
    lrs = lr_finder.history["lr"][skip_start:-skip_end]
    losses = lr_finder.history["loss"][skip_start:-skip_end]

    fig = plt.figure(figsize=(12, 9))
    plt.plot(lrs, losses)
    plt.xscale("log")
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")
    train_writer.add_figure('loss_vs_lr', fig)

    lr_finder.reset()

    fixed_lr = 3e-4
    opt = torch.optim.Adam(model.parameters(), lr=fixed_lr)
    # One cosine-annealing cycle of length args.epochs (NOTE(review): assumes
    # Trainer.train_epoch steps the scheduler once per epoch — confirm).
    scheduler = CosineAnnealingLR(opt, args.epochs)

    export_path = os.path.join(experiment_path, 'last.pth')
    best_export_path = os.path.join(experiment_path, 'best.pth')

    # Keep only the best epoch's snapshot in memory instead of deep-copying
    # every epoch's full model+optimizer state.
    best_state = None
    max_lwlrap = 0
    for epoch in range(args.epochs):
        trainer.train_epoch(model, opt, trainloader, fixed_lr, scheduler)
        metrics = trainer.eval_epoch(model, evalloader)

        state = dict(
            epoch=epoch,
            model_state_dict=model.state_dict(),
            optimizer_state_dict=opt.state_dict(),
            loss=metrics['loss'],
            lwlrap=metrics['lwlrap'],
            global_step=trainer.global_step,
        )
        # The state_dict tensors alias live parameters, so the best snapshot
        # must be deep-copied now.  Strict '>' keeps the first epoch that
        # reaches the maximum; epoch 0 is the fallback if none exceeds 0.
        if state['lwlrap'] > max_lwlrap:
            max_lwlrap = state['lwlrap']
            best_state = copy.deepcopy(state)
        elif best_state is None:
            best_state = copy.deepcopy(state)
        torch.save(state, export_path)

    # save the best state (nothing to save when args.epochs == 0)
    if best_state is not None:
        torch.save(best_state, best_export_path)
class Trainer:
    def __init__(self,
                 model,
                 criterion,
                 optimizer,
                 train_loader,
                 val_loader=None,
                 name="experiment",
                 experiments_dir="runs",
                 save_dir=None,
                 div_lr=1):
        """Set up a training run and its logging directory.

        Args:
            model: the network; moved to ``device()`` here.
            criterion: loss callable, invoked as ``criterion(outputs, targets)``.
            optimizer: optimizer over ``model``'s parameters. Its default
                ``lr`` is applied to every parameter group via ``update_lr``.
            train_loader: iterable of ``(inputs, targets)`` training batches.
            val_loader: optional iterable of validation batches; when falsy,
                ``train`` skips validation and best-checkpoint saving.
            name: suffix of the auto-generated run directory name.
            experiments_dir: parent directory that holds all runs.
            save_dir: run directory name inside ``experiments_dir``; when None
                it is built as ``<run index>-<get_git_hash()>-<name>``.
            div_lr: per-group divisor for discriminative learning rates
                (see ``update_lr``).
        """
        self.device = device()
        self.model = model.to(self.device)
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.div_lr = div_lr
        # No one-cycle schedule until train_one_cycle() installs one.
        # _train_epoch reads this attribute on every batch, so it must exist
        # even when train() is called directly.
        self.onecycle = None
        self.update_lr(self.optimizer.defaults['lr'])
        self._epoch_count = 0
        self._best_loss = None
        self._best_acc = None
        if save_dir is None:
            save_dir = f"{self.get_num_dir(experiments_dir):04d}-{get_git_hash()}-{name}"
        self._save_dir = os.path.join(experiments_dir, save_dir)
        self.writer = Logger(self._save_dir)
        # Flush scalars and close the writer when the interpreter exits.
        atexit.register(self.cleanup)

    def train(self, epochs=1):
        for epoch in range(epochs):
            self._epoch_count += 1
            print("\n----- epoch ", self._epoch_count, " -----")
            train_loss, train_acc = self._train_epoch()
            if self.val_loader:
                val_loss, val_acc = self._validate_epoch()
                if self._best_loss is None or val_loss < self._best_loss:
                    self.save_checkpoint('best_model')
                    self._best_loss = val_loss
                    print("new best val loss!")
                if self._best_acc is None or val_acc > self._best_acc:
                    self.save_checkpoint('best_model_acc')
                    self._best_acc = val_acc
                    print("new best val acc!")

    def test(self, test_loader):
        """Evaluate the model on ``test_loader`` without updating weights.

        Args:
            test_loader: iterable of ``(inputs, targets)`` batches that
                supports ``len()``.

        Returns:
            ``(epoch_loss, epoch_acc)``: the mean of the per-batch loss and
            per-batch accuracy. NOTE: every batch is weighted equally, so a
            short final batch counts as much as a full one.
        """
        self.model.eval()
        running_loss = 0
        running_acc = 0
        with torch.no_grad():
            for inputs, targets in tqdm(test_loader):
                # Move data to the device the model was placed on in __init__.
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                outputs = self.model(inputs)
                running_loss += self.criterion(outputs, targets).item()
                running_acc += accuracy(outputs, targets).item()
        epoch_loss = running_loss / len(test_loader)
        epoch_acc = running_acc / len(test_loader)
        print(f"test loss: {epoch_loss:.5f} test acc: {epoch_acc:.5f}")
        return epoch_loss, epoch_acc

    def train_one_cycle(self, epochs=1, lr=None):
        """Train for ``epochs`` epochs under a ``OneCycle`` LR/momentum schedule.

        Args:
            epochs: number of epochs; the schedule is sized for
                ``len(self.train_loader) * epochs`` batches.
            lr: learning-rate argument for ``OneCycle`` (presumably the peak
                LR); defaults to the optimizer's default ``lr``.
        """
        if lr is None:
            lr = self.optimizer.defaults['lr']
        self.onecycle = OneCycle(len(self.train_loader) * epochs, lr)
        try:
            self.train(epochs)
        finally:
            # Always uninstall the schedule, even if training is interrupted.
            # Otherwise a later plain train() keeps stepping a stale (or
            # exhausted) iterator.
            self.onecycle = None

    def _train_epoch(self, save_histogram=False):
        """Run one optimisation pass over ``self.train_loader``.

        When a one-cycle schedule is installed (``self.onecycle``), the LR and
        momentum are updated from it before every batch. Running averages are
        logged every ``log_every`` batches.

        Args:
            save_histogram: accepted for API compatibility; currently unused.

        Returns:
            ``(epoch_loss, epoch_acc)``: the mean of the per-batch loss and
            per-batch accuracy over the epoch.
        """
        self.model.train()
        running_loss = 0
        running_acc = 0
        num_batches = len(self.train_loader)
        # ``onecycle`` is installed by train_one_cycle(); a missing attribute
        # means "no schedule", so train() also works on its own.
        onecycle = getattr(self, 'onecycle', None)
        # Drop gradients left over from any earlier backward pass (e.g. an LR
        # range test). Otherwise they are folded into the first update.
        self.optimizer.zero_grad()
        for step, (inputs, targets) in enumerate(tqdm(self.train_loader)):
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            if onecycle is not None:
                lr, mom = next(onecycle)
                self.update_lr(lr)
                self.update_mom(mom)
            with torch.set_grad_enabled(True):
                outputs = self.model(inputs)
                batch_loss = self.criterion(outputs, targets)
                batch_acc = accuracy(outputs, targets)
                batch_loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
            running_loss += batch_loss.item()
            running_acc += batch_acc.item()
            if self.log_every(step):
                # x-axis: batches seen since the first epoch.
                global_step = (self._epoch_count - 1) * num_batches + step
                self.writer.add_scalars(
                    "loss", {"train_loss": running_loss / float(step + 1)},
                    global_step)
                self.writer.add_scalars(
                    "acc", {"train_acc": running_acc / float(step + 1)},
                    global_step)
        epoch_loss = running_loss / num_batches
        epoch_acc = running_acc / num_batches
        print(f"train loss: {epoch_loss:.5f} train acc: {epoch_acc:.5f}")
        return epoch_loss, epoch_acc

    def _validate_epoch(self):
        """Evaluate the model on ``self.val_loader`` without updating weights.

        Running averages are logged every ``log_every`` batches.

        Returns:
            ``(epoch_loss, epoch_acc)``: the mean of the per-batch loss and
            per-batch accuracy over the validation set.
        """
        self.model.eval()
        running_loss = 0
        running_acc = 0
        num_batches = len(self.val_loader)
        with torch.no_grad():
            for step, (inputs, targets) in enumerate(tqdm(self.val_loader)):
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                outputs = self.model(inputs)
                running_loss += self.criterion(outputs, targets).item()
                running_acc += accuracy(outputs, targets).item()
                if self.log_every(step):
                    # NOTE(review): the step is scaled by the *validation*
                    # loader length, while train logs use the training loader
                    # length, so both curves share a chart at different
                    # x-scales.
                    global_step = (self._epoch_count - 1) * num_batches + step
                    self.writer.add_scalars(
                        "loss", {"val_loss": running_loss / float(step + 1)},
                        global_step)
                    self.writer.add_scalars(
                        "acc", {"val_acc": running_acc / float(step + 1)},
                        global_step)
        epoch_loss = running_loss / num_batches
        epoch_acc = running_acc / num_batches
        print(f"val loss: {epoch_loss:.5f} val acc: {epoch_acc:.5f}")
        return epoch_loss, epoch_acc

    def get_num_dir(self, path):
        """Return the number of entries in ``path``.

        Used as the next run index when naming a run directory. A ``path``
        that does not exist yet (first run on a machine) counts as 0 instead
        of raising.
        """
        try:
            return len(os.listdir(path))
        except FileNotFoundError:
            return 0

    def save_checkpoint(self, fname):
        path = os.path.join(self._save_dir, fname)
        torch.save(
            dict(
                epoch=self._epoch_count,
                best_loss=self._best_loss,
                best_acc=self._best_acc,
                model=self.model.state_dict(),
                optimizer=self.optimizer.state_dict(),
            ), path)

    def load_checkpoint(self, fname):
        path = os.path.join(self._save_dir, fname)
        checkpoint = torch.load(path,
                                map_location=lambda storage, loc: storage)
        self._epoch_count = checkpoint['epoch']
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

    def log_every(self, i):
        """Return True for batch indices that are multiples of 100 (0, 100, ...)."""
        return not i % 100

    def update_lr(self, lr):
        n = len(self.optimizer.param_groups) - 1
        for i, g in enumerate(self.optimizer.param_groups):
            g['lr'] = lr / (self.div_lr**(n - i))

    def update_mom(self, mom):
        keys = self.optimizer.param_groups[0].keys()
        for g in self.optimizer.param_groups:
            if 'momentum' in g.keys():
                g['momentum'] = mom
            elif 'betas' in g.keys():
                g['betas'] = mom if isinstance(mom, tuple) else (mom,
                                                                 g['betas'][1])
            else:
                raise ValueError

    def find_lr(self, start_lr=1e-7, end_lr=100, num_iter=100):
        """Run an LR range test from ``start_lr`` to ``end_lr`` and plot it.

        The finder is kept on ``self.lr_finder`` so its history can be
        inspected afterwards. The range test trains the model, so the model
        weights and the optimizer state are restored before returning, even
        when the test raises.

        Args:
            start_lr: learning rate applied (via ``update_lr``) before the sweep.
            end_lr: learning rate at the end of the sweep.
            num_iter: number of training iterations in the sweep.
        """
        import copy

        # Deep copy: Optimizer.state_dict() shares the per-parameter state
        # tensors (e.g. momentum buffers) with the live optimizer, and
        # optimizer steps during the sweep update those tensors in place.
        optimizer_state = copy.deepcopy(self.optimizer.state_dict())
        self.update_lr(start_lr)
        try:
            self.lr_finder = LRFinder(self.model, self.optimizer,
                                      self.criterion, self.device)
            try:
                self.lr_finder.range_test(self.train_loader,
                                          end_lr=end_lr,
                                          num_iter=num_iter)
            finally:
                # The sweep trains with LRs up to ``end_lr``. Without a reset,
                # the model keeps the (often diverged) weights from the test.
                # NOTE(review): relies on LRFinder.reset() restoring the model
                # state cached at construction, as torch-lr-finder does.
                self.lr_finder.reset()
        finally:
            # Restore the original learning rates and optimizer state.
            self.optimizer.load_state_dict(optimizer_state)
        self.lr_finder.plot()

    def cleanup(self):
        copy_runpy(self._save_dir)
        path = os.path.join(self._save_dir, "./all_scalars.json")
        self.writer.export_scalars_to_json(path)
        self.writer.close()
        model = Model(args.model, n_class=1, pretrained=True).to(device)
        model = torch.nn.DataParallel(model)

        # criterion = nn.CrossEntropyLoss()
        criterion = nn.BCEWithLogitsLoss()

        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=args.lr,
                                      weight_decay=args.weight_decay)

        if args.find_lr:
            lr_finder = LRFinder(model, optimizer, criterion, device=device)
            lr_finder.range_test(trn_loader,
                                 start_lr=args.start_lr,
                                 end_lr=args.end_lr,
                                 num_iter=100,
                                 accumulation_steps=args.accum_iter)
            fig_name = 'lr_curve.png'
            lr_finder.plot(fig_name)
            lr_finder.reset()
            break

        scheduler = CosineAnnealingWarmRestarts(optimizer,
                                                T_0=epochs,
                                                T_mult=1,
                                                eta_min=1e-6)
        scaler = GradScaler()
        for epoch in range(epochs):
            train_one_epoch(fold,
                            epoch,