import torch
from torch.nn import CrossEntropyLoss, Linear
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.tensorboard import SummaryWriter


def train(data_loader):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = Linear(3, 1).to(device)
    lr = 1e-3
    optimizer = AdamW(net.parameters(), lr)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=1, T_mult=2)
    criterion = CrossEntropyLoss()
    # Create the writer once, not per batch; the Windows path needs a raw string.
    writer = SummaryWriter(
        log_dir=r'C:\Users\andre\OneDrive\Рабочий стол\train\checkpoints',
        comment="Batch loss")
    max_epochs = 40
    global_step = 0
    for ep_id in range(max_epochs):
        net.train()
        for b_id, batch in enumerate(data_loader):
            optimizer.zero_grad()
            output = net(batch['data'].to(device))
            loss = criterion(output, batch['ground_truth'].to(device))
            writer.add_scalar('Batch loss', loss.item(), global_step)
            global_step += 1
            loss.backward()
            optimizer.step()
        scheduler.step()
        torch.save(
            {
                'epoch': ep_id,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
            }, '/checkpoints/checkpoints.txt')
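# Side note (not part of the snippet above): a small, self-contained probe of what the
# T_0=1, T_mult=2 schedule does. The restart period doubles each cycle (1, 2, 4, 8, ...
# epochs), so with one scheduler.step() per epoch the learning rate resets at the start
# of epochs 1, 3, 7, 15, ... The optimizer here is a throwaway used only for inspection.
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

probe_opt = AdamW([torch.zeros(1, requires_grad=True)], lr=1e-3)
probe_sched = CosineAnnealingWarmRestarts(probe_opt, T_0=1, T_mult=2)
for probe_epoch in range(16):
    print(probe_epoch, probe_opt.param_groups[0]['lr'])  # lr used during this epoch
    probe_sched.step()  # one step per epoch, exactly as train() above does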
Example #2
class CosineAnnealRestartLR(object):
    def __init__(self, optimizer, args, start_iter):
        self.iter = start_iter
        self.max_iter = args['train.max_iter']
        self.lr_scheduler = CosineAnnealingWarmRestarts(
            optimizer, args['train.cosine_anneal_freq'], last_epoch=start_iter-1)
        # self.lr_scheduler = CosineAnnealingLR(
        #     optimizer, args['train.cosine_anneal_freq'], last_epoch=start_iter-1)

    def step(self, _iter):
        self.iter += 1
        self.lr_scheduler.step(_iter)
        stop_training = self.iter >= self.max_iter
        return stop_training
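# A minimal usage sketch (an assumption, not taken from the original code) of how the
# wrapper above might be driven: step() receives the global iteration index and returns
# True once 'train.max_iter' is reached, at which point the caller stops training.
import torch

demo_optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
demo_args = {'train.max_iter': 100, 'train.cosine_anneal_freq': 20}  # assumed keys/values
demo_sched = CosineAnnealRestartLR(demo_optimizer, demo_args, start_iter=0)
for it in range(10_000):
    # ... forward pass, loss.backward(), demo_optimizer.step() would go here ...
    if demo_sched.step(it):
        break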
def run(net):
    folds = prepare_folds()

    trainloader, valloader = folds[int(hps['fold_id'])]

    net = net.to(device)

    scaler = GradScaler()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-6)
    scheduler = CosineAnnealingWarmRestarts(optimizer,
                                            T_0=10,
                                            T_mult=1,
                                            eta_min=1e-6,
                                            last_epoch=-1)
    criterion = nn.CrossEntropyLoss().to(device)
    best_acc_v = 0

    print("Training", hps['name'], hps['fold_id'], "on", device)
    for epoch in range(hps['n_epochs']):

        acc_tr, loss_tr = train(net, trainloader, criterion, optimizer, scaler,
                                epoch + 1)
        acc_v, loss_v = evaluate(net, valloader, criterion, epoch + 1)

        # Step the cosine-annealing warm-restart schedule once per epoch
        scheduler.step()

        # Save the best network and print results
        if acc_v > best_acc_v:
            save(net, hps, desc=hps['fold_id'])
            best_acc_v = acc_v

            print('Epoch %2d' % (epoch + 1),
                  'Train Accuracy: %2.2f %%' % acc_tr,
                  'Train Loss: %2.6f' % loss_tr,
                  'Val Accuracy: %2.2f %%' % acc_v,
                  'Val Loss: %2.6f' % loss_v,
                  'Network Saved',
                  sep='\t\t')

        else:
            print('Epoch %2d' % (epoch + 1),
                  'Train Accuracy: %2.2f %%' % acc_tr,
                  'Train Loss: %2.6f' % loss_tr,
                  'Val Accuracy: %2.2f %%' % acc_v,
                  'Val Loss: %2.6f' % loss_v,
                  sep='\t\t')
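# The train() helper called by run() above is not shown; this is a minimal sketch of
# what it plausibly does with the GradScaler (hypothetical, for illustration only; the
# actual implementation may differ). `device` is assumed to be the same global used in run().
from torch.cuda.amp import autocast

def train(net, loader, criterion, optimizer, scaler, epoch):
    net.train()
    correct, total, loss_sum = 0, 0, 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        with autocast():                      # forward pass in mixed precision
            logits = net(images)
            loss = criterion(logits, labels)
        scaler.scale(loss).backward()         # scale the loss to avoid fp16 underflow
        scaler.step(optimizer)                # unscale gradients, then optimizer step
        scaler.update()
        loss_sum += loss.item() * labels.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return 100 * correct / total, loss_sum / total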
def main(opt):
    torch.manual_seed(opt.seed)

    # Create the artifact directories if needed, then pick the next version_* name
    os.makedirs(os.path.join("./artifacts_train", "logs"), exist_ok=True)
    os.makedirs(os.path.join("./artifacts_train", "checkpoints"), exist_ok=True)
    log_versions = [
        int(name.split("_")[-1])
        for name in os.listdir(os.path.join("./artifacts_train", "logs"))
        if os.path.isdir(os.path.join("./artifacts_train", "logs", name))
    ]
    current_version = f"version_{max(log_versions) + 1}" if log_versions else "version_0"
    logger = SummaryWriter(logdir=os.path.join("./artifacts_train", "logs", current_version))
    os.makedirs(os.path.join("./artifacts_train", "checkpoints", current_version), exist_ok=True)

    device = torch.device("cuda:0")

    # Train Val Split
    path_to_train = os.path.join(opt.data_dir, "train_images/")
    train_df = pd.read_csv(os.path.join(opt.data_dir, "train.csv"))
    train_df["image_id"] = train_df["image_id"].apply(lambda x: os.path.join(path_to_train, x))
    train_df = train_df.sample(frac=1, random_state=opt.seed).reset_index(drop=True)
    train_df.columns = ["path", "label"]

    val_df = train_df.iloc[int(len(train_df)*opt.train_split/100):].reset_index(drop=True)
    train_df = train_df.iloc[:int(len(train_df)*opt.train_split/100)].reset_index(drop=True)

    # Augmentations
    train_trans = albu.Compose([
            albu.RandomResizedCrop(*opt.input_shape),
            albu.VerticalFlip(),
            albu.HorizontalFlip(),
            albu.Transpose(p=0.5),
            albu.ShiftScaleRotate(p=0.5),
            albu.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
            albu.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
            albu.CoarseDropout(p=0.5),
            albu.Cutout(p=0.5),
            albu.Normalize()
        ])

    val_trans = albu.Compose([
            albu.RandomResizedCrop(*opt.input_shape),
            albu.VerticalFlip(),
            albu.HorizontalFlip(),
            albu.ShiftScaleRotate(p=0.5),
            albu.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
            albu.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
            albu.Normalize() 
        ])

    # Dataset init
    data_train = LeafData(train_df, transforms=train_trans)
    data_val = LeafData(val_df, transforms=val_trans, tta=opt.tta)
    weights = get_weights(train_df)
    sampler_train = WeightedRandomSampler(weights, len(data_train))
    dataloader_train = DataLoader(data_train, batch_size=opt.batch_size, sampler=sampler_train, num_workers=opt.num_workers)
    dataloader_val = DataLoader(data_val, shuffle=True, batch_size=8, num_workers=opt.num_workers)

    # Model init
    model = timm.create_model(opt.model_arch, pretrained=False)
    pretrained_path = get_pretrained(opt)
    model.load_state_dict(torch.load(pretrained_path, map_location=device))
    model.classifier = nn.Linear(model.classifier.in_features, 5)
    model.to(device)

    # freeze first opt.freeze_percent params
    param_count = len(list(model.parameters()))
    for param in list(model.parameters())[:int(param_count*opt.freeze_percent/100)]:
        param.requires_grad = False

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, weight_decay=1e-6)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=opt.num_epoch, T_mult=1, eta_min=1e-6, last_epoch=-1)
    criterion = get_loss(opt)

    best_acc = 0
    iteration_per_epoch = opt.iteration_per_epoch if opt.iteration_per_epoch else len(dataloader_train)
    for epoch in range(opt.num_epoch):
        # Train
        model.train()
        dataloader_iterator = iter(dataloader_train)
        pbar = tqdm(range(iteration_per_epoch), desc=f"Train : Epoch: {epoch + 1}/{opt.num_epoch}")
        
        for step in pbar:     
            try:
                images, labels = next(dataloader_iterator)
            except StopIteration:
                dataloader_iterator = iter(dataloader_train)
                images, labels = next(dataloader_iterator)
            
            images = images.to(device)
            labels = labels.to(device)

            if opt.adversarial_attack:
                idx_for_attack = list(rng.choice(labels.size(0), size=labels.size(0) // 4))
                images[idx_for_attack] = pgd_attack(images[idx_for_attack], labels[idx_for_attack], model, criterion)
            
            logit = model(images)
            loss = criterion(logit, labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step(epoch + step/iteration_per_epoch)
            
            _, predicted = torch.max(logit.data, 1)
            accuracy = 100 * (predicted == labels).sum().item() / labels.size(0)
            pbar.set_postfix({"Accuracy": accuracy, "Loss": loss.cpu().data.numpy().item(), "LR": optimizer.param_groups[0]["lr"]})
            logger.add_scalar('Loss/Train', loss.cpu().data.numpy().item(), epoch*iteration_per_epoch + step + 1)
            logger.add_scalar('Accuracy/Train', accuracy, epoch*iteration_per_epoch + step + 1)
            logger.add_scalar('LR/Train', optimizer.param_groups[0]["lr"], epoch*iteration_per_epoch + step + 1)
        
        # Val
        print(f"Eval start! Epoch {epoch + 1}/{opt.num_epoch}")
        correct = 0
        total = 0
        loss_sum = 0

        model.eval()
        dataloader_iterator = iter(dataloader_val)
        pbar = tqdm(range(len(dataloader_val)), desc=f"Eval : Epoch: {epoch + 1}/{opt.num_epoch}")
        for step in pbar: 
            try:
                images, labels = next(dataloader_iterator)
            except StopIteration:
                dataloader_iterator = iter(dataloader_val)
                images, labels = next(dataloader_iterator)

            labels = labels.to(device)

            if opt.tta:
                predicts = []
                loss_tta = 0
                for i in range(opt.tta):
                    img = images[i].to(device)
            
                    with torch.no_grad():
                        logit = model(img)

                    loss = criterion(logit, labels)
                    loss_tta += loss.cpu().data.numpy().item() / opt.tta
                    predicts.append(F.softmax(logit, dim=-1)[None, ...])
                
                predicts = torch.cat(predicts, dim=0).mean(dim=0)
                loss_sum += loss_tta

            else:
                images = images.to(device)
                with torch.no_grad():
                    logit = model(images)

                loss = criterion(logit, labels)
                predicts = F.softmax(logit, dim=-1)
                loss_sum += loss.cpu().data.numpy().item()

            #accuracy
            _, predicted = torch.max(predicts.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            pbar.set_postfix({"Accuracy": 100 * (predicted == labels).sum().item() / labels.size(0), "Loss": loss.cpu().data.numpy().item()})
            
        accuracy = 100 * correct / total
        loss_mean = loss_sum / len(dataloader_val)
        logger.add_scalar('Loss/Val', loss_mean, epoch*iteration_per_epoch + step + 1)
        logger.add_scalar('Accuracy/Val', accuracy, epoch*iteration_per_epoch + step + 1)
        print(f"Epoch: {epoch + 1}, Accuracy: {accuracy:.5f}, Loss {loss_mean:.5f}")
        if accuracy > best_acc:
            print("Saved checkpoint!")
            best_acc = accuracy
            torch.save({
                    "epoch": epoch + 1,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "accuracy": round(accuracy, 5),
                    "loss": round(loss_mean, 5),
                    "config": opt,
                }, 
                os.path.join("./artifacts_train", "checkpoints", current_version, f"{epoch + 1}_accuracy_{accuracy:.5f}.pth"))
def train():
    epoch_size = len(train_dataset) // args.batch_size
    num_epochs = math.ceil(args.max_iter / epoch_size)

    df = pd.read_csv('./data/train.csv')
    tmp = np.sqrt(
        1 / np.sqrt(df['landmark_id'].value_counts().sort_index().values))
    margins = (tmp - tmp.min()) / (tmp.max() - tmp.min()) * 0.45 + 0.05

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)

    scheduler = CosineAnnealingWarmRestarts(optimizer, num_epochs - 1)
    logger = open('log.txt', 'w')
    iteration = 0
    losses = AverageMeter()
    scores = AverageMeter()
    start_epoch = time.time()
    model.train()

    for epoch in range(num_epochs):
        if (epoch + 1) * epoch_size < iteration:
            continue
        if iteration == args.max_iter:
            break
        correct = 0
        start_time = time.time()
        input_size = 0

        for i, (inputs, targets) in enumerate(train_loader):
            inputs = inputs.cuda()
            targets = targets.to(device)
            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, targets, margins)
            confs, preds = torch.max(outputs.detach(), dim=1)
            optimizer.zero_grad()
            # loss.backward()
            scaler.scale(loss).backward()
            # optimizer.step()
            scaler.step(optimizer)
            scaler.update()

            input_size += inputs.size(0)
            losses.update(loss.item(), inputs.size(0))
            scores.update(gap(preds, confs, targets))
            correct += (preds == targets).float().sum()

            iteration += 1

            writer.add_scalar('train_loss', losses.val, iteration)
            writer.add_scalar('train_gap', scores.val, iteration)

            log = {
                'epoch': epoch + 1,
                'iteration': iteration,
                'loss': losses.val,
                'acc': correct.item() / input_size,
                'gap': scores.val
            }
            logger.write(str(log) + '\n')
            if iteration % args.verbose_eval == 0:
                print(
                    f'[{epoch+1}/{iteration}] Loss: {losses.val:.4f} Acc: {correct/input_size:.4f} GAP: {scores.val:.4f} LR: {scheduler.get_last_lr()} Time: {time.time() - start_time}'
                )

            if iteration % args.save_interval == 0:
                torch.save(
                    model.state_dict(),
                    f'ResNext101_448_300_{epoch+36}_{iteration+100000}.pth')

            scheduler.step(epoch + i / len(train_loader))

        print()

    logger.close()
    writer.close()
    print(time.time() - start_epoch)
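# AverageMeter and gap() above come from the surrounding project and are not shown.
# AverageMeter is the conventional helper from the PyTorch ImageNet example; a matching
# implementation (supporting .val, .avg and .update(value, n)) would be:
class AverageMeter:
    """Tracks the most recent value and a running average."""

    def __init__(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count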
Example #6
def main(opt):
    train_data, valid_data = get_train_valid_split_data_names(opt.img_folder, opt.ano_folder, valid_size=1/8)

    # Load the data
    print("load data")
    train_dataset = Phase1Dataset(train_data, load_size=(640, 640), augment=True, limit=opt.limit)
    print("train data length : %d" % (len(train_dataset)))
    valid_dataset = Phase1Dataset(valid_data, load_size=(640, 640), augment=False, limit=opt.limit)
    print("valid data length : %d" % (len(valid_dataset)))
    # Create the DataLoaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.num_workers,
        pin_memory=True,
        drop_last=True
    )

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=opt.num_workers,
        pin_memory=True,
        drop_last=True
    )

    # GPU setup (PyTorch requires the device to be specified explicitly)
    device = torch.device('cuda' if opt.gpus > 0 else 'cpu')

    # Create the model
    heads = {'hm': 1}
    model = get_pose_net(18, heads, 256).to(device)

    # Define the optimizer (it must exist before a checkpoint can be loaded into it)
    if opt.optimizer == "SGD":
        optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr)#, momentum=m, dampening=d, weight_decay=w, nesterov=n)
    elif opt.optimizer == "Adam":
        optimizer = torch.optim.Adam(model.parameters(), opt.lr)
    elif opt.optimizer == "RAdam":
        optimizer = optim.RAdam(model.parameters(), lr=opt.lr)

    # Define the loss function
    criterion = HMLoss()
    # Define the learning-rate schedule
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=0.00001)

    start_epoch = 0
    # Resume the model, optimizer, and starting epoch from a checkpoint if requested
    if opt.load_model != '':
        model, optimizer, start_epoch = load_model(
            model, opt.load_model, optimizer)
    best_validation_loss = 1e10
    # Create the folder for saving results
    os.makedirs(os.path.join(opt.save_dir, opt.task, 'visualized'), exist_ok=True)

    # Training. TODO: evaluate on the validation data and save the model at the end of every epoch
    for epoch in range(start_epoch + 1, opt.num_epochs + 1):
        print("learning rate : %f" % scheduler.get_last_lr()[0])
        train(train_loader, model, optimizer, criterion, device, opt.num_epochs, epoch)
        if opt.optimizer == "SGD":
            scheduler.step()

        # Save the latest model
        save_model(os.path.join(opt.save_dir, opt.task, 'model_last.pth'),
                   epoch, model, optimizer, scheduler)

        # Evaluate on the validation data
        validation_loss, accumulate_datas = valid(valid_loader, model, criterion, device)
        # Save the model when the best score is updated
        if validation_loss < best_validation_loss:
            best_validation_loss = validation_loss
            save_model(os.path.join(opt.save_dir, opt.task, 'model_best.pth'),
                       epoch, model, optimizer, scheduler)
            print("saved best model")
            visualization(os.path.join(opt.save_dir, opt.task, 'visualized'),
                        accumulate_datas)
def main():
    args = get_args()

    # archLoader
    arch_loader = ArchLoader(args.path)

    # Log
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m-%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}-{:02}-{:02}-{:.3f}'.format(
            local_time.tm_year % 2000, local_time.tm_mon, local_time.tm_mday,
            t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    train_dataset, val_dataset = get_dataset('cifar100')

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    model = mutableResNet20()

    logging.info('load model successfully')

    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)

    if use_gpu:
        model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")

    # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
    #                                               lambda step: (1.0-step/args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    #     optimizer, T_max=200)

    model = model.to(device)

    all_iters = 0

    if args.auto_continue:  # automatically resume from the latest checkpoint
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            logging.info('load from checkpoint')
            for i in range(iters):
                scheduler.step()

    # Parameter setup: stash the training objects on args
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_loader = train_loader
    args.val_loader = val_loader

    if args.eval:
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint, strict=True)
            validate(model,
                     device,
                     args,
                     all_iters=all_iters,
                     arch_loader=arch_loader)
        exit(0)

    # warmup weights
    if args.warmup > 0:
        logging.info("begin warmup weights")
        while all_iters < args.warmup:
            all_iters = train_supernet(model,
                                       device,
                                       args,
                                       bn_process=False,
                                       all_iters=all_iters)

        validate(model,
                 device,
                 args,
                 all_iters=all_iters,
                 arch_loader=arch_loader)

    while all_iters < args.total_iters:
        logging.info("=" * 50)
        all_iters = train_subnet(model,
                                 device,
                                 args,
                                 bn_process=False,
                                 all_iters=all_iters,
                                 arch_loader=arch_loader)

        if all_iters % 200 == 0:
            logging.info("validate iter {}".format(all_iters))

            validate(model,
                     device,
                     args,
                     all_iters=all_iters,
                     arch_loader=arch_loader)
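# CrossEntropyLabelSmooth above is a project-specific loss that is not defined in this
# snippet; a common implementation (cross-entropy with uniform label smoothing, as seen
# in several NAS repositories) looks roughly like this:
import torch
import torch.nn as nn

class CrossEntropyLabelSmooth(nn.Module):
    def __init__(self, num_classes, epsilon):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        # one-hot targets, then mix with the uniform distribution
        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        return (-targets * log_probs).mean(0).sum()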
Example #8
best_val_loss, ea_patience = 1e10, 0
for epoch in range(max_epochs):
    losses = []
    model.train()    
    tk0 = tqdm(data_loader_train)
    torch.backends.cudnn.benchmark = True
    for step, batch in enumerate(tk0):
        inputs, labels = batch["image"].cuda().float(), batch["labels"].cuda().float()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        losses.append(loss.item())
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
        tk0.set_postfix({'loss':np.nanmean(losses)})
    logs.append(np.nanmean(losses))
    
    torch.backends.cudnn.benchmark = False
    tk0 = tqdm(data_loader_val)
    val_losses = []
    model.eval()
    with torch.no_grad():
        for step, batch in enumerate(tk0):
            inputs, labels = batch["image"].cuda().float(), batch["labels"].cuda().float()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_losses.append(loss.item())
            tk0.set_postfix({'val_loss':np.nanmean(val_losses)})
        val_logs.append(np.nanmean(val_losses))
Example #9
def train():
    epoch_size = len(trainset) // args.batch_size
    num_epochs = math.ceil(args.max_iter / epoch_size)
    start_epoch = 0
    iteration = 0

    df = pd.read_csv('./data/train.csv')
    tmp = np.sqrt(1 / np.sqrt(df['landmark_id'].value_counts().sort_index().values))
    margins = (tmp - tmp.min()) / (tmp.max() - tmp.min()) * 0.45 + 0.05

    print('Loading model...')
    model = EfficientNetLandmark(args.depth, args.num_classes)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    scheduler = CosineAnnealingWarmRestarts(optimizer, num_epochs-1)

    if args.resume is not None:
        state_dict = torch.load(args.resume)
        try:
            print('Resume all state...')
            modules_state_dict = state_dict['modules']
            optimizer_state_dict = state_dict['optimizer']
            scheduler_state_dict = state_dict['scheduler']
            optimizer.load_state_dict(optimizer_state_dict)
            scheduler.load_state_dict(scheduler_state_dict)
            start_epoch = state_dict['epoch']
            iteration = state_dict['iteration']
        except KeyError:
            print('Resume only modules...')
            modules_state_dict = state_dict
        
        model_state_dict = {k.replace('module.', ''): v for k, v in modules_state_dict.items() if k.replace('module.', '') in model.state_dict().keys()}
        model.load_state_dict(model_state_dict)

    num_gpus = list(range(torch.cuda.device_count()))
    if len(num_gpus) > 1:
        print('Using data parallel...')
        model = nn.DataParallel(model, device_ids=num_gpus)

    # logger = open('log.txt', 'w')
    losses = AverageMeter()
    scores = AverageMeter()

    start_train = datetime.now()
    print(num_epochs, start_epoch, iteration)
    model.train()
    for epoch in range(start_epoch, num_epochs):
        if (epoch+1)*epoch_size < iteration:
            continue

        if iteration == args.max_iter:
            break
        
        correct = 0
        input_size = 0
        start_time = datetime.now()
        for i, (inputs, targets) in enumerate(train_loader):
            optimizer.zero_grad()

            inputs = inputs.to(device)
            targets = targets.to(device)

            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, targets, margins)
            
            confs, preds = torch.max(outputs.detach(), dim=1)
            # loss.backward()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            # optimizer.step()
            scaler.update()

            losses.update(loss.item(), inputs.size(0))
            scores.update(gap(preds, confs, targets))
            correct += (preds == targets).float().sum()
            input_size += inputs.size(0)

            iteration += 1

            writer.add_scalar('loss', losses.val, iteration)
            writer.add_scalar('gap', scores.val, iteration)

            # log = {'epoch': epoch+1, 'iteration': iteration, 'loss': losses.val, 'acc': corrects.val, 'gap': scores.val}
            # logger.write(str(log) + '\n')
            if iteration % args.verbose_eval == 0:
                print(
                    f'[{epoch+1}/{iteration}] Loss: {losses.val:.5f} Acc: {correct/input_size:.5f}' \
                    f' GAP: {scores.val:.5f} LR: {optimizer.param_groups[0]["lr"]} Time: {datetime.now() - start_time}')
            
            if iteration > 100000 and iteration % args.save_interval == 0:
                print('Save model...')
                save_checkpoint({
                    'modules': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'epoch': epoch,
                    'iteration': iteration,
                }, f'effnet_b{args.depth}_{args.max_size}_{args.batch_size}_{epoch+1}_{iteration}.pth')

            scheduler.step(epoch + i / len(train_loader))
        print()

    # logger.close()
    writer.close()
    print(datetime.now() - start_train)
    loss_list = []
    auc = 0.0
    auc_list = [0]
    for epoch in tqdm(range(epoch_num)):
        epoch_loss = 0
        for batchX, batchY in video_loader:
            batchX = batchX.cuda()
            batchY = batchY.cuda()
            score_pred = net(batchX).cuda()
            batch_loss = loss_func(score_pred, batchY)
            epoch_loss += batch_loss.item()  # .item() so the autograd graph is not kept alive across batches
            reg_optimizer.zero_grad()
            batch_loss.backward()
            reg_optimizer.step()
        print('Epoch:{}/{} Loss:{}'.format(epoch + 1, epoch_num, epoch_loss))
        scheduler.step(epoch_loss)
        loss_list.append(epoch_loss)

        if (epoch + 1) % 10 == 0:
            net.eval()
            score_list = rn_predict(seg_dir=test_seg_dir,
                                    input_dim=input_dim,
                                    net=net)
            _, _, _, cur_auc = get_roc_metric(score_list=score_list,
                                              include_normal=True,
                                              anno_dir=anno_dir,
                                              path_dir=path_dir)
            _, _, _, oc_auc = get_roc_metric(score_list=score_list,
                                             include_normal=False,
                                             anno_dir=anno_dir,
                                             path_dir=path_dir)
Example #11
def main(args):

    args = parse_args()
    tag = args.tag
    device = torch.device('cuda:0')

    no_epochs = args.epochs
    batch_size = args.batch

    linear_hidden = args.linear
    conv_hidden = args.conv

    #Get train test paths -> later on implement cross val
    steps = get_paths(as_tuples=True, shuffle=True, tag=tag)
    steps_train, steps_test = steps[:int(len(steps) *
                                         .8)], steps[int(len(steps) * .8):]

    transform = transforms.Compose([
        DepthSegmentationPreprocess(no_data_points=args.no_data),
        ToSupervised()
    ])

    dataset_train = SimpleDataset(ids=steps_train,
                                  batch_size=batch_size,
                                  transform=transform,
                                  **SENSORS)
    dataset_test = SimpleDataset(ids=steps_test,
                                 batch_size=batch_size,
                                 transform=transform,
                                 **SENSORS)

    dataloader_params = {
        'batch_size': batch_size,
        'shuffle': True,
        'num_workers': 8
    }  #we've already shuffled paths

    dataset_train = DataLoader(dataset_train, **dataloader_params)
    dataset_test = DataLoader(dataset_test, **dataloader_params)

    batch = next(iter(dataset_test))
    action_shape = batch['action'][0].shape
    img_shape = batch['img'][0].shape
    #Nets
    actor_net = DDPGActor(img_shape=img_shape,
                          numeric_shape=[len(NUMERIC_FEATURES)],
                          output_shape=[2],
                          linear_hidden=linear_hidden,
                          conv_filters=conv_hidden)
    critic_net = DDPGCritic(actor_out_shape=action_shape,
                            img_shape=img_shape,
                            numeric_shape=[len(NUMERIC_FEATURES)],
                            linear_hidden=linear_hidden,
                            conv_filters=conv_hidden)

    print(len(steps))
    print(actor_net)
    print(get_n_params(actor_net))
    print(critic_net)
    print(get_n_params(critic_net))
    # save path
    actor_net_path = f'../data/models/offline/{DATE_TIME}/{actor_net.name}'
    critic_net_path = f'../data/models/offline/{DATE_TIME}/{critic_net.name}'
    os.makedirs(actor_net_path, exist_ok=True)
    os.makedirs(critic_net_path, exist_ok=True)
    optim_steps = args.optim_steps
    logging_idx = int(len(dataset_train.dataset) / (batch_size * optim_steps))

    actor_writer_train = SummaryWriter(f'{actor_net_path}/train',
                                       max_queue=30,
                                       flush_secs=5)
    critic_writer_train = SummaryWriter(f'{critic_net_path}/train',
                                        max_queue=1,
                                        flush_secs=5)
    actor_writer_test = SummaryWriter(f'{actor_net_path}/test',
                                      max_queue=30,
                                      flush_secs=5)
    critic_writer_test = SummaryWriter(f'{critic_net_path}/test',
                                       max_queue=1,
                                       flush_secs=5)

    #Optimizers
    actor_optimizer = torch.optim.Adam(actor_net.parameters(), lr=0.001)
    critic_optimizer = torch.optim.Adam(critic_net.parameters(), lr=0.001)

    actor_scheduler = CosineAnnealingWarmRestarts(actor_optimizer,
                                                  T_0=optim_steps,
                                                  T_mult=2)
    critic_scheduler = CosineAnnealingWarmRestarts(critic_optimizer,
                                                   T_0=optim_steps,
                                                   T_mult=2)
    #Loss function
    loss_function = torch.nn.MSELoss(reduction='sum')

    actor_best_train_loss = 1e10
    critic_best_train_loss = 1e10
    actor_best_test_loss = 1e10
    critic_best_test_loss = 1e10

    for epoch_idx in range(no_epochs):
        actor_train_loss = .0
        critic_train_loss = .0
        actor_running_loss = .0
        critic_running_loss = .0
        actor_avg_max_grad = .0
        critic_avg_max_grad = .0
        actor_avg_avg_grad = .0
        critic_avg_avg_grad = .0
        for idx, batch in enumerate(iter(dataset_train)):
            global_step = int((len(dataset_train.dataset) / batch_size *
                               epoch_idx) + idx)
            batch = unpack_batch(batch=batch, device=device)
            actor_loss, critic_loss, actor_grad, critic_grad = train_rl(
                batch=batch,
                actor_net=actor_net,
                critic_net=critic_net,
                actor_optimizer=actor_optimizer,
                critic_optimizer=critic_optimizer,
                loss_fn=loss_function)
            del batch
            gc.collect()

            actor_avg_max_grad += max(
                [element.max() for element in actor_grad])
            critic_avg_max_grad += max(
                [element.max() for element in critic_grad])
            actor_avg_avg_grad += sum(
                [element.mean() for element in actor_grad]) / len(actor_grad)
            critic_avg_avg_grad += sum(
                [element.mean() for element in critic_grad]) / len(critic_grad)

            actor_running_loss += actor_loss
            critic_train_loss += critic_loss
            actor_train_loss += actor_loss
            critic_running_loss += critic_loss

            actor_writer_train.add_scalar(tag=f'{actor_net.name}/running_loss',
                                          scalar_value=actor_loss / batch_size,
                                          global_step=global_step)
            actor_writer_train.add_scalar(tag=f'{actor_net.name}/max_grad',
                                          scalar_value=actor_avg_max_grad,
                                          global_step=global_step)
            actor_writer_train.add_scalar(tag=f'{actor_net.name}/mean_grad',
                                          scalar_value=actor_avg_avg_grad,
                                          global_step=global_step)

            critic_writer_train.add_scalar(
                tag=f'{critic_net.name}/running_loss',
                scalar_value=critic_loss / batch_size,
                global_step=global_step)
            critic_writer_train.add_scalar(tag=f'{critic_net.name}/max_grad',
                                           scalar_value=critic_avg_max_grad,
                                           global_step=global_step)
            critic_writer_train.add_scalar(tag=f'{critic_net.name}/mean_grad',
                                           scalar_value=critic_avg_avg_grad,
                                           global_step=global_step)

            if idx % logging_idx == logging_idx - 1:
                print(
                    f'Actor Epoch: {epoch_idx + 1}, Batch: {idx+1}, Loss: {actor_running_loss/logging_idx}'
                )
                print(
                    f'Critic Epoch: {epoch_idx + 1}, Batch: {idx+1}, Loss: {critic_running_loss/logging_idx}'
                )
                if (critic_running_loss /
                        logging_idx) < critic_best_train_loss:
                    critic_best_train_loss = critic_running_loss / logging_idx
                    torch.save(actor_net.state_dict(),
                               f'{actor_net_path}/train/train.pt')
                    torch.save(critic_net.state_dict(),
                               f'{critic_net_path}/train/train.pt')

                actor_writer_train.add_scalar(
                    tag=f'{actor_net.name}/lr',
                    scalar_value=actor_scheduler.get_last_lr()[0],
                    global_step=global_step)
                critic_writer_train.add_scalar(
                    tag=f'{critic_net.name}/lr',
                    scalar_value=critic_scheduler.get_last_lr()[0],
                    global_step=global_step)

                actor_scheduler.step()
                critic_scheduler.step()
                actor_running_loss = .0
                actor_avg_max_grad = .0
                actor_avg_avg_grad = .0
                critic_running_loss = .0
                critic_avg_max_grad = .0
                critic_avg_avg_grad = .0

        print(
            f'{critic_net.name} best train loss for epoch {epoch_idx+1} - {critic_best_train_loss}'
        )
        actor_writer_train.add_scalar(
            tag=f'{actor_net.name}/global_loss',
            scalar_value=(actor_train_loss / (len(dataset_train.dataset))),
            global_step=(epoch_idx + 1))
        critic_writer_train.add_scalar(
            tag=f'{critic_net.name}/global_loss',
            scalar_value=(critic_train_loss / (len(dataset_train.dataset))),
            global_step=(epoch_idx + 1))
        actor_test_loss = .0
        critic_test_loss = .0
        with torch.no_grad():
            for idx, batch in enumerate(iter(dataset_test)):
                batch = unpack_batch(batch=batch, device=device)
                q_pred = critic_net(**batch)
                action_pred = actor_net(**batch)
                critic_loss = loss_function(q_pred.view(-1),
                                            batch['q']).abs().sum()
                actor_loss = loss_function(action_pred,
                                           batch['action']).abs().sum()

                critic_test_loss += critic_loss
                actor_test_loss += actor_loss

        if critic_test_loss / len(
                dataset_test.dataset) < critic_best_test_loss:
            critic_best_test_loss = (critic_test_loss /
                                     len(dataset_test.dataset))
        if actor_test_loss / len(dataset_test.dataset) < actor_best_test_loss:
            actor_best_test_loss = (actor_test_loss /
                                    len(dataset_test.dataset))

        torch.save(critic_net.state_dict(),
                   f'{critic_net_path}/test/test_{epoch_idx+1}.pt')
        torch.save(actor_net.state_dict(),
                   f'{actor_net_path}/test/test_{epoch_idx+1}.pt')

        print(
            f'{critic_net.name} test loss {(critic_test_loss/len(dataset_test.dataset)):.3f}'
        )
        print(
            f'{actor_net.name} test loss {(actor_test_loss/len(dataset_test.dataset)):.3f}'
        )
        print(f'{critic_net.name} best test loss {critic_best_test_loss:.3f}')
        print(f'{actor_net.name} best test loss {actor_best_test_loss:.3f}')

        critic_writer_test.add_scalar(
            tag=f'{critic_net.name}/global_loss',
            scalar_value=(critic_test_loss / (len(dataset_test.dataset))),
            global_step=(epoch_idx + 1))
        actor_writer_test.add_scalar(
            tag=f'{actor_net.name}/global_loss',
            scalar_value=(actor_test_loss / (len(dataset_test.dataset))),
            global_step=(epoch_idx + 1))
        torch.cuda.empty_cache()
        gc.collect()

    torch.save(actor_optimizer.state_dict(),
               f=f'{actor_net_path}/{actor_optimizer.__class__.__name__}.pt')
    torch.save(critic_optimizer.state_dict(),
               f=f'{critic_net_path}/{critic_optimizer.__class__.__name__}.pt')
    json.dump(vars(args),
              fp=open(f'{actor_net_path}/args.json', 'w'),
              sort_keys=True,
              indent=4)
    json.dump(vars(args),
              fp=open(f'{critic_net_path}/args.json', 'w'),
              sort_keys=True,
              indent=4)

    actor_writer_train.flush()
    actor_writer_test.flush()
    actor_writer_train.close()
    actor_writer_test.close()
    critic_writer_train.flush()
    critic_writer_test.flush()
    critic_writer_train.close()
    critic_writer_test.close()

    batch = next(iter(dataset_test))
    batch = unpack_batch(batch=batch, device=device)

    #Actor architecture save
    y = actor_net(**batch)
    g = make_dot(y, params=dict(actor_net.named_parameters()))
    g.save(filename=f'{DATE_TIME}_{actor_net.name}.dot',
           directory=actor_net_path)
    #Critic architecture save
    y = critic_net(**batch)
    g = make_dot(y, params=dict(critic_net.named_parameters()))
    g.save(filename=f'{DATE_TIME}_{critic_net.name}.dot',
           directory=critic_net_path)

    check_call([
        'dot', '-Tpng', '-Gdpi=200',
        f'{critic_net_path}/{DATE_TIME}_{critic_net.name}.dot', '-o',
        f'{critic_net_path}/{DATE_TIME}_{critic_net.name}.png'
    ])
    check_call([
        'dot', '-Tpng', '-Gdpi=200',
        f'{actor_net_path}/{DATE_TIME}_{actor_net.name}.dot', '-o',
        f'{actor_net_path}/{DATE_TIME}_{actor_net.name}.png'
    ])
for i in range(N_EPOCHS):
    logger.new_epoch()
    # train
    classifier.train()

    epoch_trn_loss = []
    epoch_vld_loss = []
    epoch_vld_recall_g, epoch_vld_recall_v, epoch_vld_recall_c, epoch_vld_recall_all = [], [], [], []

    for j, (trn_imgs_batch, trn_lbls_batch) in enumerate(training_loader):
        # move to device
        trn_imgs_batch_device = trn_imgs_batch.cuda()
        trn_lbls_batch_device = trn_lbls_batch.cuda()

        # lr scheduler step
        lr_scheduler.step(i + j / nsteps)
        cur_lr = lr_scheduler.get_last_lr()

        # mixup
        trn_imgs_batch_device_mixup, trn_lbls_batch_device_shfl, gamma = mixup(
            trn_imgs_batch_device, trn_lbls_batch_device, 1.)

        # forward pass
        logits_g, logits_v, logits_c = classifier(trn_imgs_batch_device_mixup)

        loss_g = mixup_loss(logits_g, trn_lbls_batch_device[:, 0],
                            trn_lbls_batch_device_shfl[:, 0], gamma)
        loss_v = mixup_loss(logits_v, trn_lbls_batch_device[:, 1],
                            trn_lbls_batch_device_shfl[:, 1], gamma)
        loss_c = mixup_loss(logits_c, trn_lbls_batch_device[:, 2],
                            trn_lbls_batch_device_shfl[:, 2], gamma)
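# mixup() and mixup_loss() used above are not defined in this snippet; a minimal sketch
# consistent with those call signatures (alpha is the Beta-distribution parameter,
# 1.0 in the call above) might look like this:
import numpy as np
import torch
import torch.nn.functional as F

def mixup(x, y, alpha=1.0):
    """Blend each sample with a randomly chosen partner; return the mixed inputs,
    the shuffled labels, and the mixing coefficient gamma."""
    gamma = float(np.random.beta(alpha, alpha))
    perm = torch.randperm(x.size(0), device=x.device)
    x_mixed = gamma * x + (1 - gamma) * x[perm]
    return x_mixed, y[perm], gamma

def mixup_loss(logits, y, y_shuffled, gamma):
    return gamma * F.cross_entropy(logits, y) + (1 - gamma) * F.cross_entropy(logits, y_shuffled)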
def main():
    args = get_args()
    num_gpus = torch.cuda.device_count()
    args.gpu = args.local_rank % num_gpus
    torch.cuda.set_device(args.gpu)

    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.world_size = torch.distributed.get_world_size()
    args.batch_size = args.batch_size // args.world_size

    # archLoader
    arch_loader = ArchLoader(args.path)

    # Log
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m-%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}-{:02}-{:02}-{:.3f}'.format(
            local_time.tm_year % 2000, local_time.tm_mon, local_time.tm_mday,
            t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    train_loader = get_train_loader(args.batch_size, args.local_rank,
                                    args.num_workers, args.total_iters)

    val_loader = get_val_loader(args.batch_size, args.num_workers)

    model = mutableResNet20()

    logging.info('load model successfully')

    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)

    if use_gpu:
        # model = nn.DataParallel(model)
        model = model.cuda(args.gpu)
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
        loss_function = criterion_smooth.cuda()
    else:
        loss_function = criterion_smooth

    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5)

    all_iters = 0

    if args.auto_continue:  # automatically resume from the latest checkpoint
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            logging.info('load from checkpoint')
            for i in range(iters):
                scheduler.step()

    # Parameter setup: stash the training objects on args
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_loader = train_loader
    args.val_loader = val_loader

    if args.eval:
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint, strict=True)
            validate(model, args, all_iters=all_iters, arch_loader=arch_loader)
        exit(0)

    # warmup weights
    if args.warmup > 0:
        logging.info("begin warmup weights")
        while all_iters < args.warmup:
            all_iters = train_supernet(model,
                                       args,
                                       bn_process=False,
                                       all_iters=all_iters)

        validate(model, args, all_iters=all_iters, arch_loader=arch_loader)

    while all_iters < args.total_iters:
        logging.info("=" * 50)
        all_iters = train_subnet(model,
                                 args,
                                 bn_process=False,
                                 all_iters=all_iters,
                                 arch_loader=arch_loader)

        if all_iters % 200 == 0 and args.local_rank == 0:
            logging.info("validate iter {}".format(all_iters))

            validate(model, args, all_iters=all_iters, arch_loader=arch_loader)
Example #14
def main():
    global args, best_performance

    set_seed(args.rand_seed)

    if args.model == 'FCNet':
        # dataloader
        train_loader, valid_loader, test_loader = get_FCNet_train_valid_test_loader(
            root=args.data_root,
            target=args.target,
            max_Miller=args.max_Miller,
            diffraction=args.diffraction,
            cell_type=args.cell_type,
            permute_hkl=args.fcnet_permute_hkl,
            randomize_hkl=args.fcnet_randomize_hkl,
            batch_size=args.batch_size,
            num_data_workers=args.num_data_workers)
        # construct model
        model = FCNet(max_Miller=args.max_Miller,
                      fc_dims=args.fcnet_fc_dims,
                      dropout=args.dropout)
    elif args.model == 'PointNet':
        # dataloader
        train_loader, valid_loader, test_loader = get_PointNet_train_valid_test_loader(
            root=args.data_root,
            target=args.target,
            max_Miller=args.max_Miller,
            diffraction=args.diffraction,
            cell_type=args.cell_type,
            randomly_scale_intensity=args.pointnet_randomly_scale_intensity,
            systematic_absence=args.pointnet_systematic_absence,
            batch_size=args.batch_size,
            num_data_workers=args.num_data_workers)
        # construct model
        model = PointNet(conv_filters=args.pointnet_conv_filters,
                         fc_dims=args.pointnet_fc_dims,
                         dropout=args.dropout)
    else:
        raise NotImplementedError

    # send model to device
    if torch.cuda.is_available():
        print('running on GPU:\n')
    else:
        print('running on CPU\n')
    model = model.to(args.device)

    # show number of trainable model parameters
    trainable_params = sum(p.numel() for p in model.parameters()
                           if p.requires_grad)
    print(
        'Number of trainable model parameters: {:d}'.format(trainable_params))

    # define loss function
    criterion = torch.nn.NLLLoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # HDFS
    if args.hdfs_dir is not None:
        os.system(f'hdfs dfs -mkdir -p {args.hdfs_dir}')

    # optionally resume from a checkpoint
    if args.restore_path != '':
        assert os.path.isfile(args.restore_path)
        print("=> loading checkpoint '{}'".format(args.restore_path),
              flush=True)
        checkpoint = torch.load(args.restore_path,
                                map_location=torch.device('cpu'))
        args.start_epoch = checkpoint['epoch'] + 1
        best_performance = checkpoint['best_performance']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.restore_path, checkpoint['epoch']),
              flush=True)

    # learning-rate scheduler
    scheduler = CosineAnnealingWarmRestarts(optimizer=optimizer,
                                            T_0=args.epochs,
                                            eta_min=1E-8)

    print('\nStart training..', flush=True)
    for epoch in range(args.start_epoch, args.start_epoch + args.epochs):
        lr = scheduler.get_last_lr()
        logging.info('Epoch: {}, LR: {:.6f}'.format(epoch, lr[0]))

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        performance = validate(valid_loader, model, criterion)

        scheduler.step()

        # check performance
        is_best = performance > best_performance
        best_performance = max(performance, best_performance)

        # save checkpoint
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_performance': best_performance,
                'optimizer': optimizer.state_dict(),
            }, is_best, args)

    # test best model
    print('---------Evaluate Model on Test Set---------------', flush=True)
    best_model = load_best_model()
    print('best validation performance: {:.3f}'.format(
        best_model['best_performance']))
    model.load_state_dict(best_model['state_dict'])
    validate(test_loader, model, criterion, test_mode=True)
Example #15
class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)

    def forward(self, x):
        pass


net_1 = model()

optimizer_1 = torch.optim.Adam(net_1.parameters(), lr=initial_lr)
scheduler_1 = CosineAnnealingWarmRestarts(optimizer_1, T_0=1)

print("初始化的学习率:", optimizer_1.defaults['lr'])

lr_list = []  # record every lr that was used so we can plot how it changes later

for epoch in range(0, 6):
    # train
    for i in range(int(30000 / 32)):
        optimizer_1.zero_grad()
        optimizer_1.step()
        print("第%d个epoch的学习率:%f" % (epoch, optimizer_1.param_groups[0]['lr']))
        lr_list.append(optimizer_1.param_groups[0]['lr'])
        scheduler_1.step((epoch + i + 1) / int(30000 / 32))

# plot how the learning rate changed
plt.plot(lr_list)
plt.xlabel("epoch")
plt.ylabel("lr")
plt.title("learning rate's curve changes as epoch goes on!")
plt.show()
Example #16
model.train()
global_step = 0
epoch = 0
while True:
    # for j in range(50):
    for batch, (inp, target) in enumerate(dl):
        inp, target = inp.to(device), target.to(device)
        opt.zero_grad()

        out = model(inp)

        loss = F.cross_entropy(out, target)
        loss.backward()

        opt.step()
        sched.step()
        global_step += 1  # advance the step counter used for wandb logging

        acc = accuracy(F.softmax(out, dim=1), target)[0].item()
        loss_i = loss.item()

        print(
            f'{epoch:2}/{batch:3} Loss: {loss_i:.7f} Accuracy: {acc:.7f} LR: {sched.get_last_lr()[0]:.7f}'
        )
        wandb.log(
            {
                'loss': loss_i,
                'accuracy': acc,
                'lr': sched.get_last_lr()[0]
            },
            step=global_step)
Example #17
        g_loss_list.append(G_loss.item())
        d_loss_list.append((clean_loss.item() + noisy_loss.item()) / 2)

    G_LOSS = np.mean(np.array(g_loss_list))
    D_LOSS = np.mean(np.array(d_loss_list))

    generator_writer.add_scalar('Loss', G_LOSS, global_step=epoch)
    discriminator_writer.add_scalar('Loss', D_LOSS, global_step=epoch)
    generator_writer.add_scalar('LR',
                                g_optimizer.param_groups[0]['lr'],
                                global_step=epoch)
    discriminator_writer.add_scalar('LR',
                                    d_optimizer.param_groups[0]['lr'],
                                    global_step=epoch)

    g_lr_change.step()
    d_lr_change.step()

    # test_bar = tqdm(test_data_loader, desc='Test model')
    # generator = generator.cpu().eval()
    # discriminator = discriminator.cpu().eval()
    # G_test_loss = []
    # for test_clean, test_noisy in test_bar:
    #     z = torch.rand([test_noisy.size(0), 1024, 8])
    #     fake_speech = generator(test_noisy, z)
    #     outputs = discriminator(torch.cat((fake_speech, test_noisy), dim=1), test_ref_batch)
    #     l1_dist = torch.mean(torch.abs(torch.add(fake_speech, torch.neg(test_clean))))
    #     G_loss = 0.5 * torch.mean((outputs - 1.0) ** 2) + 100.0 * l1_dist
    #
    #     log = 'Epoch {}: test_G_loss {:.4f}'.format(epoch + 1, G_loss.data)
    #     test_bar.set_description(log)
Example #18
def main():
    global args, best_loss, weight_decay, momentum

    net = Network()

    epoch = 0
    saved = load_checkpoint()

    # This part still needs to be modified
    dataTrain = get_train_set(
        data_dir='C:/Users/hasee/Desktop/Master_Project/Step2/Plan_B/Label')
    dataVal = get_val_set(
        data_dir='C:/Users/hasee/Desktop/Master_Project/Step2/Plan_B/Label')

    train_loader = torch.utils.data.DataLoader(dataTrain,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(dataVal,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=workers,
                                             pin_memory=True)

    # If the results are acceptable, consider removing weight_decay and trying again
    optimizer = torch.optim.SGD(net.parameters(),
                                lr,
                                momentum=momentum,
                                weight_decay=weight_decay)
    # Cosine annealing
    if Cosine_lr:
        lr_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5)
    else:
        lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.92)

    if doLoad:
        if saved:
            print('Loading checkpoint for epoch %05d ...' % (saved['epoch']))
            state = saved['model_state']
            try:
                net.module.load_state_dict(state)
            except AttributeError:
                net.load_state_dict(state)
            epoch = saved['epoch']
            best_loss = saved['best_loss']
            optimizer.load_state_dict(saved['optim_state'])
            lr_scheduler.load_state_dict(saved['scheule_state'])
        else:
            print('Warning: Could not read checkpoint!')

    # Quick test
    if doTest:
        validate(val_loader, net, epoch)
        return
    '''
    for epoch in range(0, epoch):
        adjust_learning_rate(optimizer, epoch)
    '''

    m.begin_run(train_loader)
    print("start to run!")
    for epoch in range(epoch, epochs):

        # adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, net, optimizer, epoch)
        lr_scheduler.step()

        # evaluate on validation set
        loss_total = validate(val_loader, net, epoch)

        # remember best loss and save checkpoint
        is_best = loss_total < best_loss
        best_loss = min(loss_total, best_loss)
        state = {
            'epoch': epoch + 1,
            'model_state': net.state_dict(),
            'optim_state': optimizer.state_dict(),
            'scheule_state': lr_scheduler.state_dict(),
            'best_loss': best_loss,
        }
        save_checkpoint(state, is_best)
    # End this run
    m.end_run()
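
Example #18 above saves the model, optimizer and scheduler state_dicts together so a run can be resumed. A minimal, self-contained sketch of saving and restoring all three with the same keys used in the `state` dict above (the model, hyperparameters and file name are placeholders):

# Minimal sketch: checkpoint and restore model, optimizer and scheduler together.
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

net = torch.nn.Linear(10, 2)                      # placeholder model
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
lr_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5)

torch.save({
    'epoch': 1,
    'model_state': net.state_dict(),
    'optim_state': optimizer.state_dict(),
    'scheule_state': lr_scheduler.state_dict(),
    'best_loss': 0.5,
}, 'checkpoint.pth.tar')                          # hypothetical file name

# ...later, resume from the checkpoint:
saved = torch.load('checkpoint.pth.tar')
net.load_state_dict(saved['model_state'])
optimizer.load_state_dict(saved['optim_state'])
lr_scheduler.load_state_dict(saved['scheule_state'])
epoch = saved['epoch']
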
Example #19
0
    }],
                        lr=model_config.learning_rate,
                        weight_decay=5e-4)
    scheduler = CosineAnnealingWarmRestarts(opt, T_0=1)
    # scheduler = CyclicLR(opt, base_lr=1e-4, max_lr=model_config.learning_rate, step_size_up=2000)

    for epoch in tqdm(range(model_config.epochs), desc='epochs'):
        tr_loss = 0

        encoder.train()
        decoder.train()

        for step, mb in tqdm(enumerate(tr_dl), desc='steps', total=len(tr_dl)):
            mb_loss = 0
            src_mb, tgt_mb = map(lambda elm: elm.to(device), mb)
            scheduler.step(epoch + step / len(tr_dl))
            opt.zero_grad()
            # encoder
            enc_outputs_mb, src_length_mb, enc_hc_mb = encoder(src_mb)

            # decoder
            dec_input_mb = torch.ones((tgt_mb.size()[0], 1),
                                      device=device).long()
            dec_input_mb *= tgt_vocab.to_indices(tgt_vocab.bos_token)
            dec_hc_mb = enc_hc_mb

            tgt_length_mb = tgt_mb.ne(
                tgt_vocab.to_indices(tgt_vocab.padding_token)).sum(dim=1)
            tgt_mask_mb = sequence_mask(tgt_length_mb, tgt_length_mb.max())

            use_teacher_forcing = True if random.random(
Example #20
0
def train(epochs: int=1,
          batch_size: int=800,
          model_name: str='resnet34',
          logdir: str='/tmp/loops/',
          lrs: tuple=(1e-4, 1e-3, 5e-3),
          eta_min: float=1e-6,
          dev_id: int=1,
          visdom_host: str='0.0.0.0',
          visdom_port: int=9001):

    vis = Visdom(server=visdom_host, port=visdom_port,
                 username=os.environ['VISDOM_USERNAME'],
                 password=os.environ['VISDOM_PASSWORD'])

    experiment_id = f'{model_name}_e{epochs}_b{batch_size}'
    device = torch.device(f'cuda:{dev_id}')
    # dataset = create_data_loaders(*load_data(), batch_size=batch_size)
    dataset = DataBunch().create(*load_data(), batch_size=batch_size)
    model = get_model(model_name, NUM_CLASSES).to(device)
    freeze_model(model)
    unfreeze_layers(model, ['conv1', 'bn1', 'layer4', 'last_linear'])

    loss_fn = nn.CrossEntropyLoss()
    conv, layer, head = lrs
    opt = torch.optim.AdamW([
        {'params': model.conv1.parameters(), 'lr': conv},
        {'params': model.layer4.parameters(), 'lr': layer},
        {'params': model.last_linear.parameters(), 'lr': head}
    ], weight_decay=0.01)
    logdir = os.path.join(logdir, experiment_id)
    sched = CosineAnnealingWarmRestarts(
        opt, T_0=len(dataset['train']), T_mult=2, eta_min=eta_min)
    rolling_loss = RollingLoss()
    os.makedirs(logdir, exist_ok=True)
    iteration = 0

    for epoch in range(1, epochs+1):
        trn_dl = dataset['train']
        n = len(trn_dl)

        model.train()
        with tqdm(total=n) as bar:
            for i, batch in enumerate(trn_dl, 1):
                iteration += 1
                if i % 25 == 0:
                    for j, g in enumerate(opt.param_groups):
                        vis.line(X=[iteration], Y=[g['lr']],
                                 win=f'metrics{j}', name=f'lr{j}', update='append')
                bar.set_description(f'[epoch:{epoch}/{epochs}][{i}/{n}]')
                opt.zero_grad()
                x = batch['features'].to(device)
                y = batch['targets'].to(device)
                out = model(x)
                loss = loss_fn(out, y)
                loss.backward()
                avg_loss = rolling_loss(loss.item(), iteration+1)
                opt.step()
                sched.step()
                bar.set_postfix(avg_loss=f'{avg_loss:.3f}')
                bar.update(1)
                vis.line(X=[iteration], Y=[avg_loss],
                         win='loss', name='avg_loss', update='append')

        val_dl = dataset['valid']
        n = len(val_dl)

        model.eval()
        with torch.no_grad():
            matches = []
            with tqdm(total=n) as bar:
                for batch in val_dl:
                    x = batch['features'].to(device)
                    y = batch['targets'].to(device)
                    out = model(x)
                    y_pred = out.softmax(dim=1).argmax(dim=1)
                    matched = (y == y_pred).detach().cpu().numpy().tolist()
                    matches.extend(matched)
                    bar.update(1)
            acc = np.mean(matches)
            vis.line(X=[epoch], Y=[acc], win='acc', name='val_acc', update='append')
            print(f'validation accuracy: {acc:2.2%}')
            acc_str = str(int(round(acc * 10_000, 0)))
            path = os.path.join(logdir, f'train.{epoch}.{acc_str}.pth')
            torch.save(model.state_dict(), path)
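
In Example #20 the scheduler's T_0 equals the number of training batches per epoch and sched.step() is called after every opt.step(), so the first cosine cycle spans exactly one epoch and each following cycle doubles in length (T_mult=2). A minimal sketch of that per-batch stepping, with a synthetic model and an assumed loader length:

# Minimal sketch: per-batch stepping with T_0 set to one epoch's worth of batches.
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

model = torch.nn.Linear(4, 2)                     # placeholder model
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
batches_per_epoch = 100                           # assumed loader length
sched = CosineAnnealingWarmRestarts(opt, T_0=batches_per_epoch, T_mult=2, eta_min=1e-6)

for epoch in range(3):
    for _ in range(batches_per_epoch):
        opt.zero_grad()
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        opt.step()
        sched.step()                              # restarts land after epochs 1, 3, 7, ...
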
Example #21
0
class Trainer:
    def __init__(self, **kwargs):
        
        if kwargs['use_gpu']:
            print('Using GPU')
            self.device = t.device('cuda')
        else:
            self.device = t.device('cpu')

        # data
        if not os.path.exists(kwargs['root_dir'] + '/train_image_list'):
            image_list = glob(kwargs['root_dir'] + '/img_train/*')
            image_list = sorted(image_list)

            mask_list = glob(kwargs['root_dir'] + '/lab_train/*')
            mask_list = sorted(mask_list)

        else:
            image_list = open(kwargs['root_dir'] + '/train_image_list', 'r').readlines()
            image_list = [line.strip() for line in image_list]
            image_list = sorted(image_list)
            mask_list = open(kwargs['root_dir'] + '/train_label_list', 'r').readlines()
            mask_list = [line.strip() for line in mask_list]
            mask_list = sorted(mask_list)

        print(image_list[-5:], mask_list[-5:])

        if kwargs['augs']:
            augs = Augment(get_augmentations())
        else:
            augs = None

        self.train_loader, self.val_loader = build_loader(image_list, mask_list, kwargs['test_size'], augs, preprocess, \
            kwargs['num_workers'], kwargs['batch_size'])

        self.model = build_model(kwargs['in_channels'], kwargs['num_classes'], kwargs['model_name']).to(self.device)

        if kwargs['resume']:
            try:
                self.model.load_state_dict(t.load(kwargs['resume']))
                
            except Exception as e:
                self.model.load_state_dict(t.load(kwargs['resume'])['model'])

            print(f'load model from {kwargs["resume"]} successfully')
        if kwargs['loss'] == 'CE':
            self.criterion = nn.CrossEntropyLoss().to(self.device)
        elif kwargs['loss'] == 'FL':
            self.criterion = FocalLoss().to(self.device)
        else:
            raise NotImplementedError
        if kwargs['use_sgd']:
            self.optimizer = SGD(self.model.parameters(), lr=kwargs['lr'], momentum=kwargs['momentum'], nesterov=True, weight_decay=float(kwargs['weight_decay'])) 

        else:
            self.optimizer = Adam(self.model.parameters(), lr=kwargs['lr'], weight_decay=float(kwargs['weight_decay']))

        self.lr_planner = CosineAnnealingWarmRestarts(self.optimizer, 100, T_mult=2, eta_min=1e-6, verbose=True)

        log_dir = os.path.join(kwargs['log_dir'], datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))

        self.writer = SummaryWriter(log_dir, comment=f"LR-{kwargs['lr']}_BatchSize-{kwargs['batch_size']}_ModelName-{kwargs['model_name']}")

        self.name = kwargs['model_name']
        self.epoch = kwargs['epoch']

        self.start_epoch = kwargs['start_epoch']

        self.val_interval = kwargs['val_interval']

        self.log_interval = kwargs['log_interval']

        self.num_classes = kwargs['num_classes']

        self.checkpoints = kwargs['checkpoints']

        self.color_dicts = kwargs['color_dicts']


        # , format='%Y-%m-%d %H:%M:%S',
        logging.basicConfig(filename=log_dir + '/log.log', level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S')

        s = '\n\t\t'
        for k, v in kwargs.items():
            s += f"{k}: \t{v}\n\t\t"

        logging.info(s)

    def train(self):
        max_miou = 0
        for epoch in range(self.start_epoch, self.epoch):
            logging.info(f'Train Epoch-{epoch}',)
            self.model.train()
            history = self.step(epoch)

            logging.info(f'Loss: {history[0]}; MIOU: {history[1]}, CareMIOU: {history[2]}')

            # write to tensorboard, log it
            logging.info(f'Val Epoch - {epoch}')
            history = self.eval()
            logging.info(f'MIOU: {history[0]}; CareMIOU: {history[1]}')
            self.writer.add_scalar('MIOU/Val', history[0], epoch)

            # Save the best model
            if history[0] > max_miou:
                self.save_model(f'{self.name}-{epoch}_MIOU-{history[0]}_CareMIOU-{history[1]}.pth')
                max_miou = history[0]
            

    def step(self, epoch):

        total_loss = 0
        correct = 0
        count = 0
        care_total = 0
        care_correct = 0
        tbar = tqdm(enumerate(self.train_loader), desc=f'Epoch:{epoch}', unit='batch', total=len(self.train_loader))
        for it, (image, mask) in tbar:
            image = image.to(self.device)
            mask = mask.to(self.device)
            count += mask.numel()
            care_mask = mask < 7
            care_total += care_mask.sum().item()

            self.optimizer.zero_grad()

            pred = self.model(image)

            correct += (pred.argmax(1) == mask).sum().item()
            care_correct += (pred.argmax(1)[care_mask] == mask[care_mask]).sum().item()

            loss = self.criterion(pred, mask)

            total_loss += loss.item()

            loss.backward()

            self.optimizer.step()

            tbar.set_description(f'[Iter/Total: {it}/{len(self.train_loader)}, Loss: {round(total_loss / (it + 1), 4)}, Accuracy/IOU: {round(float(correct) / count, 4)}, Care-IOU: {round(float(care_correct) / care_total, 4)}]')
            global_step = (epoch - self.start_epoch) * len(self.train_loader) + it
            if global_step % self.log_interval == 0:
                
                self.writer.add_scalar('Loss/train', total_loss / (it + 1), global_step)
                self.writer.add_scalar('Lr/Train', self.optimizer.param_groups[0]['lr'], global_step)
                self.writer.add_scalar('MIOU/Train', float(correct) / count, global_step)
                self.writer.add_scalar('Care-MIOU/Train', float(care_correct) / care_total, global_step)
                self.lr_planner.step()

            # if (it + 1) % (self.log_interval * 10) == 0:
            if global_step % (self.log_interval * 10) == 0:
                if image.shape[0] < 4:
                    mask = vis_mask(mask.cpu(), self.color_dicts)
                    self.writer.add_image('Masks/True', make_grid(mask, nrow=2, padding=10), global_step)
                
                    pred = vis_mask(pred.argmax(1).cpu(), self.color_dicts)
                    self.writer.add_image('Masks/Pred', make_grid(pred, nrow=2, padding=10), global_step)  
                else:
                    mask = mask.cpu()
                    idx = t.randint(0, mask.shape[0], (4, ))
                    mask = vis_mask(mask.index_select(0, idx), self.color_dicts)
                    self.writer.add_image('Masks/True', make_grid(mask, nrow=2, padding=10), global_step)

                    pred = vis_mask(pred.argmax(1).cpu().index_select(0, idx), self.color_dicts)
                    self.writer.add_image('Masks/Pred', make_grid(pred, nrow=2, padding=10), global_step)  

        return total_loss / len(self.train_loader), float(correct) / count, float(care_correct) / care_total


    def save_model(self, save_name, save_opt=True):
        dicts = {}

        dicts['model'] = self.model.state_dict()

        dicts['optimizer'] = self.optimizer.state_dict()

        os.makedirs(self.checkpoints, exist_ok=True)

        save_path = os.path.join(self.checkpoints, save_name)
        t.save(dicts, save_path)

        logging.info("Save model to {}".format(save_path))


    def load_model(self, model_path, load_opt=True):
        if not os.path.exists(model_path):
            print(f'{model_path} does not exist, please check it')
        try:
            dicts = t.load(model_path)
        except Exception as e:
            print(e.args)
            print('Load model failed')
            return

        self.optimizer.load_state_dict(dicts['optimizer'])
        self.model.load_state_dict(dicts['model'])
        logging.info(f'Load model from {model_path} successfully...')


    def eval(self):
        self.model.eval()
        tbar = tqdm(enumerate(self.val_loader), desc='Model Evaluation', unit='batch', total=len(self.val_loader))
        correct = 0
        total = 0
        care_total = 0
        care_correct = 0
        with t.no_grad():
            for it, (image, mask) in tbar:
                image = image.to(self.device)
                # mask = mask.to(self.device)

                pred = self.model(image).cpu()

                total += mask.numel()
                
                care_mask = mask < 7
                care_total += care_mask.sum().item()


                correct += (mask == pred.argmax(1)).sum().item()
                # print(mask[care_mask].shape, pred.argmax(1)[care_mask].shape)
                true_mask = (mask[care_mask] == pred.argmax(1)[care_mask])
                care_correct += true_mask.sum().item()

        return float(correct) / total, float(care_correct) / care_total
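
Note that in Example #21 `self.lr_planner.step()` sits inside the `if global_step % self.log_interval == 0` branch, so the schedule advances only once per logging interval rather than once per batch; one cosine period of T_0=100 scheduler steps therefore covers 100 * log_interval optimizer updates. A minimal sketch of that relationship with assumed numbers:

# Minimal sketch: stepping the scheduler every `log_interval` iterations stretches
# one cosine period of T_0=100 steps over 100 * log_interval optimizer updates.
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
lr_planner = CosineAnnealingWarmRestarts(opt, 100, T_mult=2, eta_min=1e-6)
log_interval = 10                                  # assumed value

for it in range(1000):
    opt.step()
    if it % log_interval == 0:
        lr_planner.step()                          # 100 such calls complete the first cycle
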
Example #22
0
    def train(pathDirData, pathFileTrain, pathFileVal, nnArchitecture,
              nnIsTrained, nnClassCount, trBatchSize, trMaxEpoch, transResize,
              transCrop, launchTimestamp, checkpoint):

        #-------------------- SETTINGS: NETWORK ARCHITECTURE
        if nnArchitecture == 'DENSE-NET-121':
            model = DenseNet121(nnClassCount, nnIsTrained).cuda()
        elif nnArchitecture == 'DENSE-NET-169':
            model = DenseNet169(nnClassCount, nnIsTrained).cuda()
        elif nnArchitecture == 'DENSE-NET-201':
            model = DenseNet201(nnClassCount, nnIsTrained).cuda()
        elif nnArchitecture == 'RES-NET-18':
            model = ResNet18(nnClassCount, nnIsTrained).cuda()
        elif nnArchitecture == 'RES-NET-50':
            model = ResNet50(nnClassCount, nnIsTrained).cuda()

        model = torch.nn.DataParallel(model).cuda()

        #-------------------- SETTINGS: DATA TRANSFORMS
        normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225])

        transformList = []
        transformList.append(transforms.RandomResizedCrop(transCrop))
        transformList.append(transforms.RandomHorizontalFlip())
        transformList.append(transforms.ToTensor())
        transformList.append(normalize)
        transformSequence = transforms.Compose(transformList)

        #-------------------- SETTINGS: DATASET BUILDERS
        datasetTrain = DatasetGenerator(pathImageDirectory=pathDirData,
                                        pathDatasetFile=pathFileTrain,
                                        transform=transformSequence)
        datasetVal = DatasetGenerator(pathImageDirectory=pathDirData,
                                      pathDatasetFile=pathFileVal,
                                      transform=transformSequence)

        dataLoaderTrain = DataLoader(dataset=datasetTrain,
                                     batch_size=trBatchSize,
                                     shuffle=True,
                                     num_workers=24,
                                     pin_memory=True)
        dataLoaderVal = DataLoader(dataset=datasetVal,
                                   batch_size=trBatchSize,
                                   shuffle=False,
                                   num_workers=24,
                                   pin_memory=True)

        #-------------------- SETTINGS: OPTIMIZER & SCHEDULER
        optimizer = optim.Adam(model.parameters(),
                               lr=0.0001,
                               betas=(0.9, 0.999),
                               eps=1e-08,
                               weight_decay=1e-5)
        # scheduler = ReduceLROnPlateau(optimizer, factor = 0.1, patience = 5, mode = 'min')
        scheduler = CosineAnnealingWarmRestarts(optimizer,
                                                T_0=10,
                                                eta_min=1e-6)

        #-------------------- SETTINGS: LOSS
        loss = torch.nn.BCELoss(reduction='mean')

        #---- Load checkpoint
        if checkpoint is not None:
            modelCheckpoint = torch.load(checkpoint)
            model.load_state_dict(modelCheckpoint['state_dict'])
            optimizer.load_state_dict(modelCheckpoint['optimizer'])

        #---- TRAIN THE NETWORK

        lossMIN = 100000

        for epochID in range(0, trMaxEpoch):

            timestampTime = time.strftime("%H%M%S")
            timestampDate = time.strftime("%d%m%Y")
            timestampSTART = timestampDate + '-' + timestampTime

            ChexnetTrainer.epochTrain(model, dataLoaderTrain, optimizer,
                                      scheduler, trMaxEpoch, nnClassCount,
                                      loss)
            lossVal, losstensor = ChexnetTrainer.epochVal(
                model, dataLoaderVal, optimizer, scheduler, trMaxEpoch,
                nnClassCount, loss)

            timestampTime = time.strftime("%H%M%S")
            timestampDate = time.strftime("%d%m%Y")
            timestampEND = timestampDate + '-' + timestampTime

            # The commented step() calls below passed the loss, which only applies to
            # ReduceLROnPlateau; CosineAnnealingWarmRestarts is stepped without it.
            #             scheduler.step(losstensor.data[0])
            # scheduler.step(losstensor.data)
            scheduler.step()

            if lossVal < lossMIN:
                lossMIN = lossVal
                torch.save(
                    {
                        'epoch': epochID + 1,
                        'state_dict': model.state_dict(),
                        'best_loss': lossMIN,
                        'optimizer': optimizer.state_dict()
                    }, './models/m-' + launchTimestamp + '.pth.tar')
                print('Epoch [' + str(epochID + 1) + '] [save] [' +
                      timestampEND + '] loss= ' + str(lossVal))
            else:
                print('Epoch [' + str(epochID + 1) + '] [----] [' +
                      timestampEND + '] loss= ' + str(lossVal))
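
Example #22 keeps a commented-out ReduceLROnPlateau alternative next to the warm-restart scheduler; the two take different step() arguments, which is why the plateau variant received the validation loss. A minimal sketch of both call patterns (the optimizer, switch and loss values are placeholders):

# Minimal sketch: ReduceLROnPlateau consumes the validation metric in step(),
# while CosineAnnealingWarmRestarts is stepped without it.
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ReduceLROnPlateau

opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
use_plateau = False                                # assumed switch
if use_plateau:
    scheduler = ReduceLROnPlateau(opt, factor=0.1, patience=5, mode='min')
else:
    scheduler = CosineAnnealingWarmRestarts(opt, T_0=10, eta_min=1e-6)

for epoch in range(20):
    val_loss = 1.0 / (epoch + 1)                   # placeholder validation loss
    if use_plateau:
        scheduler.step(val_loss)
    else:
        scheduler.step()
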
Example #23
0
def trainer(model,
            optimizer,
            train_loader,
            test_loader,
            epochs=5,
            gpus=1,
            tasks=1,
            classifacation=False,
            mae=False,
            pb=True,
            out="model.pt",
            cyclic=False,
            verbose=True):
    device = next(model.parameters()).device
    if classifacation:
        tracker = trackers.ComplexPytorchHistory(
        ) if tasks > 1 else trackers.PytorchHistory(
            metric=metrics.roc_auc_score, metric_name='roc-auc')
    else:
        tracker = trackers.ComplexPytorchHistory(
        ) if tasks > 1 else trackers.PytorchHistory()

    earlystopping = EarlyStopping(patience=50, delta=1e-5)
    if cyclic:
        lr_red = CosineAnnealingWarmRestarts(optimizer, T_0=20)
    else:
        lr_red = ReduceLROnPlateau(optimizer,
                                   mode='min',
                                   factor=0.8,
                                   patience=20,
                                   cooldown=0,
                                   verbose=verbose,
                                   threshold=1e-4,
                                   min_lr=1e-8)

    for epochnum in range(epochs):
        train_loss = 0
        test_loss = 0
        train_iters = 0
        test_iters = 0
        model.train()
        if pb:
            gen = tqdm(enumerate(train_loader))
        else:
            gen = enumerate(train_loader)
        for i, (drugfeats, value) in gen:
            optimizer.zero_grad()
            drugfeats, value = drugfeats.to(device), value.to(device)
            pred, attn = model(drugfeats)

            if classifacation:
                mse_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                    pred, value).mean()
            elif mae:
                mse_loss = torch.nn.functional.l1_loss(pred, value).mean()
            else:
                mse_loss = torch.nn.functional.mse_loss(pred, value).mean()
            mse_loss.backward()
            torch.nn.utils.clip_grad_value_(model.parameters(), 10.0)
            optimizer.step()
            train_loss += mse_loss.item()
            train_iters += 1
            tracker.track_metric(pred=pred.detach().cpu().numpy(),
                                 value=value.detach().cpu().numpy())

        tracker.log_loss(train_loss / train_iters, train=True)
        tracker.log_metric(internal=True, train=True)

        model.eval()
        with torch.no_grad():
            for i, (drugfeats, value) in enumerate(test_loader):
                drugfeats, value = drugfeats.to(device), value.to(device)
                pred, attn = model(drugfeats)

                if classifacation:
                    mse_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                        pred, value).mean()
                elif mae:
                    mse_loss = torch.nn.functional.l1_loss(pred, value).mean()
                else:
                    mse_loss = torch.nn.functional.mse_loss(pred, value).mean()
                test_loss += mse_loss.item()
                test_iters += 1
                tracker.track_metric(pred.detach().cpu().numpy(),
                                     value.detach().cpu().numpy())
        tracker.log_loss(test_loss / test_iters, train=False)
        tracker.log_metric(internal=True, train=False)

        # ReduceLROnPlateau needs the validation metric; the warm-restart scheduler does not.
        if cyclic:
            lr_red.step()
        else:
            lr_red.step(test_loss / test_iters)
        earlystopping(test_loss / test_iters)
        if verbose:
            print("Epoch", epochnum, train_loss / train_iters,
                  test_loss / test_iters, tracker.metric_name,
                  tracker.get_last_metric(train=True),
                  tracker.get_last_metric(train=False))

        if out is not None:
            if gpus == 1:
                state = model.state_dict()
                heads = model.nheads
            else:
                state = model.module.state_dict()
                heads = model.module.nheads
            torch.save(
                {
                    'model_state': state,
                    'opt_state': optimizer.state_dict(),
                    'history': tracker,
                    'nheads': heads,
                    'ntasks': tasks
                }, out)
        if earlystopping.early_stop:
            break
    return model, tracker
Example #24
0
class Trainer:
    _classifier: Classifier
    _train_set: TinyImagenetDataset
    _test_set: TinyImagenetDataset
    _results_path: Path
    _device: torch.device
    _batch_size: int
    _num_workers: int
    _num_visual: int
    _aug_degree: Dict
    _lr: float
    _lr_min: float
    _stopper: Stopper
    _labels_num2txt: Dict
    _freeze: Dict
    _weight_decay: float
    _label_smooth: float
    _period_cosine: int

    _net_path: Path
    _tensorboard_path: Path
    _writer: SummaryWriter
    _optimizer: Optimizer
    _scheduler: CosineAnnealingWarmRestarts
    _loss_func_smoothed: Callable
    _loss_func_one_hot: nn.Module
    _curr_epoch: int
    _vis_per_batch: int

    def __init__(self, classifier: Classifier, train_set: TinyImagenetDataset,
                 test_set: TinyImagenetDataset, results_path: Path,
                 device: torch.device, batch_size: int, num_workers: int,
                 num_visual: int, aug_degree: Dict, lr: float, lr_min: float,
                 stopper: Stopper, labels_num2txt: Dict, freeze: Dict,
                 weight_decay: float, label_smooth: float, period_cosine: int):

        self._classifier = classifier
        self._train_set = train_set
        self._test_set = test_set
        self._results_path = results_path
        self._device = device
        self._batch_size = batch_size
        self._num_workers = num_workers
        self._num_visual = num_visual
        self._aug_degree = aug_degree
        self._lr = lr
        self._lr_min = lr_min
        self._stopper = stopper
        self._labels_num2txt = labels_num2txt
        self._freeze = freeze
        self._weight_decay = weight_decay
        self._label_smooth = label_smooth
        self._period_cosine = period_cosine

        self._classifier.to(self._device)

        self._net_path, self._tensorboard_path = self._results_path / 'net', self._results_path / 'tensorboard'

        for folder in [self._net_path, self._tensorboard_path]:
            folder.mkdir(exist_ok=True, parents=True)

        self._writer = SummaryWriter(log_dir=str(self._tensorboard_path))

        self._loss_func_smoothed = cross_entropy_with_probs
        self._loss_func_one_hot = nn.CrossEntropyLoss()

        self._optimizer = torch.optim.SGD(self._classifier.parameters(),
                                          lr=self._lr,
                                          momentum=0.9,
                                          weight_decay=self._weight_decay)

        self._scheduler = CosineAnnealingWarmRestarts(self._optimizer,
                                                      T_0=self._period_cosine,
                                                      eta_min=self._lr_min,
                                                      T_mult=2)

        self._curr_epoch = 0
        self._vis_per_batch = self._calc_vis_per_batch()

    def _calc_vis_per_batch(self) -> int:
        num_batch = ceil(len(self._test_set) / self._batch_size)
        return ceil(self._num_visual / num_batch)

    def train(self, num_epoch: int) -> Tuple[float, float, int]:
        accuracy_max = 0
        accuracy_test = 0
        loss_test = 0

        for epoch in range(num_epoch):

            self._curr_epoch = epoch

            self._set_freezing()

            accuracy_train, loss_train = self._train_epoch()

            accuracy_test, loss_test = self.test()

            self._writer.flush()

            self._stopper.update(accuracy_test)

            meta = {
                'epoch': self._curr_epoch,
                'accuracy_test': accuracy_test,
                'accuracy_train': accuracy_train,
                'loss_test': loss_test,
                'loss_train': loss_train,
                'lr': self._scheduler.get_last_lr()[0]
            }

            self._classifier.save(
                self._net_path / f'net_epoch_{self._curr_epoch}.pth', meta)

            if accuracy_test > accuracy_max:
                accuracy_max = accuracy_test

                self._classifier.save(self._net_path / 'best.pth', meta)

            if self._stopper.is_need_stop():
                break

        self._writer.close()
        return accuracy_test, loss_test, self._curr_epoch

    def _train_epoch(self) -> Tuple[float, float]:
        self._classifier.train()
        self._set_aug_train()
        train_loader = DataLoader(self._train_set,
                                  batch_size=self._batch_size,
                                  num_workers=self._num_workers,
                                  shuffle=True)

        num_batches = len(train_loader)
        correct = 0

        train_tqdm = tqdm(train_loader, desc=f'train_{self._curr_epoch}')
        loss_avg = AvgMoving()
        for curr_batch, (images, labels, _) in enumerate(train_tqdm):
            self._optimizer.zero_grad()

            images = images.to(self._device)
            labels = labels.to(self._device)

            output = self._classifier(images)

            if self._label_smooth > 0:
                smoothed_labels = smooth_one_hot(labels=labels,
                                                 num_classes=200,
                                                 smoothing=self._label_smooth)
                loss = self._loss_func_smoothed(output, smoothed_labels)
            else:
                loss = self._loss_func_one_hot(output, labels)

            loss.backward()

            self._optimizer.step()
            # Step the scheduler after the optimizer (PyTorch >= 1.1 convention),
            # using a fractional epoch so the cosine schedule advances every batch.
            self._scheduler.step(self._curr_epoch +
                                 (curr_batch + 1) / num_batches)

            loss = loss.item()
            loss_avg.add(loss)

            pred = output.argmax(dim=1)
            correct += torch.eq(pred, labels).sum().item()

            train_tqdm.set_postfix({'Avg train loss': round(loss_avg.avg, 4)})

        accuracy = correct / len(train_loader.dataset)
        self._add_writer_metrics(loss_avg.avg, accuracy, 'train')

        return accuracy, loss_avg.avg

    def test(self) -> Tuple[float, float]:
        self._classifier.eval()
        with torch.no_grad():
            test_loader = DataLoader(self._test_set,
                                     batch_size=self._batch_size,
                                     num_workers=self._num_workers,
                                     shuffle=True)
            correct = 0
            test_tqdm = tqdm(test_loader,
                             desc=f'test_{self._curr_epoch}',
                             leave=False)
            loss_avg = AvgMoving()

            labels_pred_list = []

            worst_pred_list = []
            best_pred_list = []
            some_pred_list = []

            for images, labels, paths in test_tqdm:
                images = images.to(self._device)
                labels = labels.to(self._device)

                output = self._classifier(images)

                # CrossEntropyLoss expects raw logits (it applies log-softmax
                # internally), so compute the loss before taking the softmax.
                loss_avg.add(self._loss_func_one_hot(output, labels).item())

                output = softmax(output, dim=1)

                pred = output.argmax(dim=1)
                correct += torch.eq(pred, labels).sum().item()

                labels = labels.detach().cpu().numpy()
                output = output.detach().cpu().numpy()
                pred = pred.detach().cpu().numpy()
                prob = output[np.arange(np.size(labels)), pred]

                # prepare data for confusion matrix and hists
                labels_pred = np.hstack(
                    (labels[:, np.newaxis], pred[:, np.newaxis]))
                labels_pred_list.append(np.copy(labels_pred))

                # prepare data for visualization
                prob_in_labels = output[np.arange(np.size(labels)), labels]

                worst_pred, best_pred, some_pred = self._prepare_data_for_vis(
                    np.copy(pred), np.copy(prob), np.copy(labels),
                    np.copy(prob_in_labels), paths)

                worst_pred_list.append(worst_pred)
                best_pred_list.append(best_pred)
                some_pred_list.append(some_pred)

            accuracy = correct / len(test_loader.dataset)

            self._add_writer_metrics(loss_avg.avg, accuracy, 'test')
            self._visual_confusion_and_hists(labels_pred_list)
            self._visual_gt_and_pred(worst_pred_list, 'Worst_pred_pic')
            self._visual_gt_and_pred(best_pred_list, 'Best_pred_pic')
            self._visual_gt_and_pred(some_pred_list, 'Some_pred_pic')

        self._classifier.train()
        return accuracy, loss_avg.avg

    def _freeze_except_k_last(self, k_last: int) -> None:
        num_layers = 0
        for _ in self._classifier._net.children():
            num_layers += 1

        for i, layer in enumerate(self._classifier._net.children()):
            if i + k_last < num_layers:
                for param in layer.parameters():
                    param.requires_grad = False

    def _unfreeze_all(self) -> None:
        for param in self._classifier.parameters():
            param.requires_grad = True

    def _set_freezing(self) -> None:
        curr_epoch = str(self._curr_epoch)
        if curr_epoch in tuple(self._freeze.keys()):
            self._unfreeze_all()
            self._freeze_except_k_last(self._freeze[curr_epoch])

            self._optimizer = torch.optim.SGD(filter(
                lambda p: p.requires_grad, self._classifier.parameters()),
                                              lr=self._lr,
                                              momentum=0.9,
                                              weight_decay=self._weight_decay)
            self._scheduler = CosineAnnealingWarmRestarts(
                self._optimizer,
                T_0=self._period_cosine,
                eta_min=self._lr_min,
                T_mult=2)

    def _set_aug_train(self) -> None:
        curr_epoch = str(self._curr_epoch)
        if curr_epoch in tuple(self._aug_degree.keys()):
            self._train_set.set_transforms(self._aug_degree[curr_epoch])

    def _add_writer_metrics(self, loss: float, accuracy: float,
                            mode: str) -> None:
        self._writer.add_scalars('loss', {f'loss_{mode}': loss},
                                 self._curr_epoch)
        self._writer.add_scalars('accuracy', {f'accuracy_{mode}': accuracy},
                                 self._curr_epoch)


# VISUALIZATION

    def _prepare_data_for_vis(
            self, pred: np.ndarray, prob: np.ndarray, labels: np.ndarray,
            prob_in_labels: np.ndarray,
            paths: Path) -> List[Tuple[List[Path], np.ndarray, np.ndarray]]:
        prob_idx = np.argsort(prob_in_labels)
        prob_idx_list = list()
        prob_idx_list.append(prob_idx[0:self._vis_per_batch])  # worst preds
        prob_idx_list.append(prob_idx[-1:-self._vis_per_batch -
                                      1:-1])  # best preds
        prob_idx_list.append(np.random.choice(
            prob_idx, self._vis_per_batch))  # some preds

        data_list = []
        for idx in prob_idx_list:
            pred_and_prob = np.hstack(
                (pred[idx, np.newaxis], prob[idx, np.newaxis]))

            labels_and_prob_in_labels = np.hstack(
                (labels[idx, np.newaxis], prob_in_labels[idx, np.newaxis]))

            data = ([paths[k]
                     for k in idx], pred_and_prob, labels_and_prob_in_labels)
            data_list.append(data)

        return data_list

    def _visual_gt_and_pred(self, data_list: List[Tuple[List[Path], np.ndarray,
                                                        np.ndarray]],
                            txt: str) -> None:
        num_batch = ceil(len(self._test_set) / self._batch_size)
        num_elem = num_batch * self._vis_per_batch

        images_array = np.zeros((num_elem, 64, 64, 3), int)
        pred_and_prob_array = np.zeros((num_elem, 2))
        labels_and_prob_in_labels_array = np.zeros((num_elem, 2))
        filenames = []
        k_im = 0
        for i, (paths, pred_and_prob,
                labels_and_prob_in_labels) in enumerate(data_list):
            filenames += paths
            for path in paths:
                images_array[k_im, :] = np.array(
                    Image.open(path).convert('RGB'))
                k_im += 1

            pred_and_prob_array[i * self._vis_per_batch: (i + 1) * self._vis_per_batch, :] = \
                pred_and_prob

            labels_and_prob_in_labels_array[i * self._vis_per_batch: (i + 1) * self._vis_per_batch, :] = \
                labels_and_prob_in_labels

        idx = np.random.choice(range(num_elem),
                               self._num_visual,
                               replace=False)
        images_array = images_array[idx, :]
        pred_and_prob_array = pred_and_prob_array[idx, :]
        labels_and_prob_in_labels_array = labels_and_prob_in_labels_array[
            idx, :]

        filenames = [Path(filenames[idx_curr]).stem for idx_curr in idx]

        height_fig = self._num_visual
        width_fig = 3
        height_cell = 0.95 * (height_fig / self._num_visual) / height_fig
        width_im_cell = 1 / width_fig
        left_im_cell = 0 / width_fig
        width_txt_cell = 1.8 / width_fig
        left_txt_cell = 1.1 / width_fig
        bottom_cell = [x / height_fig for x in range(height_fig)]

        fig = plt.figure(figsize=(width_fig, height_fig), tight_layout=False)

        for k in range(self._num_visual):
            fig.add_axes(
                (left_im_cell, bottom_cell[k], width_im_cell, height_cell))
            plt.axis('off')
            plt.imshow(images_array[k, :], aspect='auto')

            fig.add_axes(
                (left_txt_cell, bottom_cell[k], width_txt_cell, height_cell))
            str_pic = f'{filenames[k]} \n' \
                f'gt: {round(labels_and_prob_in_labels_array[k, 1], 2)}\n' \
                f'({int(labels_and_prob_in_labels_array[k, 0])}) ' \
                f'{self._labels_num2txt[labels_and_prob_in_labels_array[k, 0]]}\n' \
                f'pred: {round(pred_and_prob_array[k, 1], 2)}\n' \
                f'({int(pred_and_prob_array[k, 0])}) ' \
                f'{self._labels_num2txt[pred_and_prob_array[k, 0]]}\n'

            plt.text(0, 0.5, str_pic, verticalalignment='center')
            plt.axis('off')

        self._writer.add_figure(txt, fig, self._curr_epoch)

        plt.close(fig)

    def _visual_confusion_and_hists(
            self, labels_pred_list: List[np.ndarray]) -> None:
        labels = np.zeros(len(self._test_set))
        pred = np.zeros(len(self._test_set))

        k_batch = 0

        for labels_pred in labels_pred_list:
            diap = min((self._batch_size, np.shape(labels_pred)[0]))
            labels[k_batch * self._batch_size:k_batch * self._batch_size +
                   diap] = labels_pred[:, 0]
            pred[k_batch * self._batch_size:k_batch * self._batch_size +
                 diap] = labels_pred[:, 1]

            k_batch += 1

        confusion_matrix_array = confusion_matrix(y_pred=pred,
                                                  y_true=labels).astype(float)

        # Normalize per-class counts by the 50 validation images per class in Tiny ImageNet
        confusion_matrix_array /= 50

        fig = plt.figure(figsize=(12, 12))
        conf_map = plt.imshow(confusion_matrix_array,
                              cmap="gist_heat",
                              interpolation="nearest")
        plt.colorbar(mappable=conf_map)
        self._writer.add_figure('Confusion_matrix', fig, self._curr_epoch)

        plt.close(fig)

        self._visual_hists(confusion_matrix_array)

    def _visual_hists(self, confusion_matrix_array: np.ndarray) -> None:
        num_col = 20
        correct = np.diag(confusion_matrix_array) * 100

        idx_correct = np.argsort(correct)

        idx_best = idx_correct[-1:-num_col - 1:-1]
        idx_worst = idx_correct[0:num_col]

        self._visual_hist(np.copy(correct[idx_best]),
                          idx_best,
                          ylabel='Correct predicts, %',
                          title='Best predicts',
                          tag='Best_predicts_hist')

        self._visual_hist(np.copy(correct[idx_worst]),
                          idx_worst,
                          ylabel='Correct predicts, %',
                          title='Worst predicts',
                          tag='Worst_predicts_hist')

    def _visual_hist(self, data: np.ndarray, labels: np.ndarray, ylabel: str,
                     title: str, tag: str) -> None:
        num_col = np.size(data)
        len_txt = 20
        xlim = (0, num_col + 1)
        ylim = (0, max(np.max(data) * 1.1, 1e-5))

        fig, ax = plt.subplots(figsize=(7, 7), facecolor='white')
        ax.set(ylabel=ylabel, ylim=ylim, xlim=xlim)

        labels_on_graph = []
        for k in range(np.size(labels)):
            str_tick = self._labels_num2txt[labels[k]]
            len_str_tick = len(str_tick)
            if len_str_tick < len_txt:
                str_tick = str_tick + ' ' * (len_txt - len_str_tick)
            elif len_str_tick > len_txt:
                str_tick = str_tick[0:len_txt]

            labels_on_graph.append(str_tick)

        for i in range(num_col):
            val = data[i]
            ax.text(i + 1,
                    val + ylim[1] * 0.01,
                    np.round(val).astype(int),
                    horizontalalignment='center')
            ax.vlines(x=i + 1,
                      ymin=0,
                      ymax=val,
                      color='firebrick',
                      alpha=0.7,
                      linewidth=20)

        plt.xticks(range(1, num_col + 1), labels_on_graph, rotation=-90)
        plt.title(title)
        fig.tight_layout()

        self._writer.add_figure(tag, fig, self._curr_epoch)

        plt.close(fig)
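
Example #24 rebuilds both the SGD optimizer and the warm-restart scheduler whenever `_set_freezing` changes which layers are trainable, so the new schedule only covers parameters that still require gradients. A minimal sketch of that rebuild, using a placeholder network and hyperparameters:

# Minimal sketch: after (un)freezing layers, recreate the optimizer over the
# trainable parameters only and start a fresh warm-restart schedule.
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

net = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))

for param in net[0].parameters():                  # freeze the first layer
    param.requires_grad = False

optimizer = torch.optim.SGD(
    filter(lambda p: p.requires_grad, net.parameters()),
    lr=0.01, momentum=0.9, weight_decay=1e-4)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, eta_min=1e-5, T_mult=2)
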
Example #25
0
    def fit(self,
            model,
            num_epoch=200,
            batch_size=150,
            learning_rate=0.01,
            momentum=0.9,
            l2=1e-4,
            t_0=200,
            eta_min=1e-4,
            grad_clip=0,
            mode="standard",
            warm_up_epoch=0,
            retrain=False,
            device="cuda:0",
            verbose=True):
        """
        Train a given model on the trainer dataset using the givens hyperparameters.

        :param model: The model to train.
        :param num_epoch: Maximum number of epoch during the training. (Default=200)
        :param batch_size: The batch size that will be used during the training. (Default=150)
        :param learning_rate: Start learning rate of the SGD optimizer. (Default=0.1)
        :param momentum: Momentum that will be used by the SGD optimizer. (Default=0.9)
        :param l2: L2 regularization coefficient.
        :param t_0: Number of epoch before the first restart. (Default=200)
        :param eta_min: Minimum value of the learning rate. (Default=1e-4)
        :param grad_clip: Max norm of the gradient. If 0, no clipping will be applied on the gradient. (Default=0)
        :param mode: The training type: Option: Standard training (No mixup) (Default)
                                                Mixup (Standard manifold mixup)
                                                AdaMixup (Adaptative mixup)
                                                ManifoldAdaMixup (Manifold Adaptative mixup)
        :param warm_up_epoch: Number of iteration before activating mixup. (Default=True)
        :param retrain: If false, the weights of the model will initialize. (Default=False)
        :param device: The device on which the training will be done. (Default="cuda:0", first GPU)
        :param verbose: If true, show the progress of the training. (Default=True)
        """
        # Indicator for early stopping
        best_accuracy = 0
        best_epoch = -1
        current_mode = "Standard"

        # Track the best saved validation loss separately, since the mixup loss is always larger than the standard loss.
        last_saved_loss = float("inf")

        # Initialization of the model.
        self.device = device
        self.model = model.to(device)
        self.model.configure_mixup_module(batch_size, device)

        if retrain:
            start_epoch, last_saved_loss, best_accuracy = self.model.restore(
                self.save_path)
        else:
            self.model.apply(init_weights)
            start_epoch = 0

        # Initialization of the dataloader
        train_loader = dataset_to_loader(self.trainset,
                                         batch_size,
                                         shuffle=True,
                                         pin_memory=self.pin_memory)
        valid_loader = dataset_to_loader(self.validset,
                                         batch_size,
                                         shuffle=False,
                                         pin_memory=self.pin_memory)

        # Initialization of the optimizer and the scheduler
        optimizer = torch.optim.SGD(self.model.parameters(),
                                    lr=learning_rate,
                                    momentum=momentum,
                                    weight_decay=l2,
                                    nesterov=True)

        n_iters = len(train_loader)
        scheduler = CosineAnnealingWarmRestarts(optimizer,
                                                T_0=t_0 * n_iters,
                                                T_mult=1,
                                                eta_min=eta_min)
        scheduler.step(start_epoch * n_iters)

        # Go in training mode to activate mixup module
        self.model.train()

        with tqdm(total=num_epoch, initial=start_epoch,
                  disable=(not verbose)) as t:
            for epoch in range(start_epoch, num_epoch):

                _grad_clip = 0 if epoch > num_epoch / 2 else grad_clip
                current_mode = mode if warm_up_epoch <= epoch else current_mode

                # We make a training epoch
                if current_mode == "Mixup":
                    training_loss = self.mixup_epoch(train_loader, optimizer,
                                                     scheduler, _grad_clip)
                else:
                    training_loss = self.standard_epoch(
                        train_loader, optimizer, scheduler, _grad_clip)

                self.model.eval()
                current_accuracy, val_loss = self.accuracy(
                    dt_loader=valid_loader, get_loss=True)
                self.model.train()

                # ------------------------------------------------------------------------------------------
                #                                   EARLY STOPPING PART
                # ------------------------------------------------------------------------------------------

                if (val_loss < last_saved_loss and current_accuracy >= best_accuracy) or \
                        (val_loss < last_saved_loss*(1+self.tol) and current_accuracy > best_accuracy):
                    self.save_checkpoint(epoch, val_loss, current_accuracy)
                    best_accuracy = current_accuracy
                    last_saved_loss = val_loss
                    best_epoch = epoch

                if verbose:
                    t.postfix = "train loss: {:.4f}, val loss: {:.4f}, val acc: {:.2f}%, best acc: {:.2f}%, " \
                                "best epoch: {}, epoch type: {}".format(
                                 training_loss, val_loss, current_accuracy * 100, best_accuracy * 100, best_epoch + 1,
                                 current_mode)
                    t.update()
        self.model.restore(self.save_path)
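
In Example #25 the restart period T_0 is expressed in optimizer steps (t_0 * n_iters), and `scheduler.step(start_epoch * n_iters)` fast-forwards the schedule when a run is resumed partway through. A minimal sketch of that fast-forward with assumed numbers:

# Minimal sketch: when T_0 is given in iterations, a resumed run can jump to the
# matching point of the cosine curve by stepping with the absolute iteration count.
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.01, momentum=0.9)
n_iters = 500                                      # batches per epoch (assumed)
scheduler = CosineAnnealingWarmRestarts(opt, T_0=200 * n_iters, T_mult=1, eta_min=1e-4)

start_epoch = 3                                    # epoch loaded from a checkpoint
scheduler.step(start_epoch * n_iters)              # resume the schedule at that iteration
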
Example #26
0
    loss_list = []
    auc = 0.0
    auc_list = [0]
    for epoch in tqdm(range(epoch_num)):
        epoch_loss = 0
        for batchX, batchY in video_loader:
            batchX = batchX.cuda()
            batchY = batchY.cuda()
            score_pred = net(batchX).cuda()
            batch_loss = loss_func(score_pred, batchY)
            epoch_loss += batch_loss.item()
            reg_optimizer.zero_grad()
            batch_loss.backward()
            reg_optimizer.step()
        print('Epoch:{}/{} Loss:{}'.format(epoch + 1, epoch_num, epoch_loss))
        scheduler.step(epoch)
        loss_list.append(epoch_loss)

        if (epoch + 1) % 10 == 0:
            net.eval()
            score_list = rn_predict(seg_dir=test_seg_dir,
                                    input_dim=input_dim,
                                    net=net)
            _, _, _, cur_auc = get_roc_metric(score_list=score_list,
                                              include_normal=True,
                                              anno_dir=anno_dir,
                                              path_dir=path_dir)
            _, _, _, oc_auc = get_roc_metric(score_list=score_list,
                                             include_normal=False,
                                             anno_dir=anno_dir,
                                             path_dir=path_dir)
Example #27
0
class Trainer:
    def __init__(self, device, config, model, criterion, dataloader, data_transformer=None, tensorboard=True, meta_data=None):
        super().__init__()

        config['running_loss_range'] = config['running_loss_range'] if 'running_loss_range' in config else 50

        self.device = torch.device('cuda' if torch.cuda.is_available() and device == 'cuda' else 'cpu')
        self.config = config

        self.model = model.to(self.device)
        self.criterion = criterion

        self.ds_length = len(dataloader.dataset)
        self.batch_size = dataloader.batch_size
        self.dataloader = dataloader
        self.data_transformer = data_transformer

        self.epoch_loss_history = []
        self.content_loss_history = []
        self.style_loss_history = []
        self.total_variation_loss_history = []
        self.loss_history = []
        self.lr_history = []
        self.meta_data = meta_data

        self.progress_bar = trange(
            math.ceil(self.ds_length / self.batch_size) * self.config['epochs'],
            leave=True
        )

        if self.config['lr_scheduler'] == 'CyclicLR':
            self.optimizer = optim.SGD(self.model.parameters(), lr=config['max_lr'], nesterov=True, momentum=0.9)
            self.scheduler = CyclicLR(
                optimizer=self.optimizer,
                base_lr=config['min_lr'],
                max_lr=config['max_lr'],
                step_size_up=self.config['lr_step_size'],
                mode='triangular2'
            )
        elif self.config['lr_scheduler'] == 'CosineAnnealingLR':
            self.optimizer = optim.Adam(self.model.parameters(), lr=config['max_lr'])
            self.scheduler = CosineAnnealingLR(
                optimizer=self.optimizer,
                T_max=self.config['lr_step_size']
            )
        elif self.config['lr_scheduler'] == 'ReduceLROnPlateau':
            self.optimizer = optim.Adam(self.model.parameters(), lr=config['max_lr'])
            self.scheduler = ReduceLROnPlateau(
                optimizer=self.optimizer,
                patience=100
            )
        elif self.config['lr_scheduler'] == 'StepLR':
            self.optimizer = optim.Adam(self.model.parameters(), lr=config['max_lr'])
            self.scheduler = StepLR(
                optimizer=self.optimizer,
                step_size=self.config['lr_step_size'],
                gamma=float(self.config['lr_multiplicator'])
            )
        elif self.config['lr_scheduler'] == 'CosineAnnealingWarmRestarts':
            self.optimizer = optim.Adam(self.model.parameters(), lr=config['max_lr'])
            self.scheduler = CosineAnnealingWarmRestarts(
                optimizer=self.optimizer,
                T_0=self.config['lr_step_size'],
                eta_min=config['min_lr'],
                T_mult=int(self.config['lr_multiplicator'])
            )
        else:
            self.optimizer = optim.Adam(self.model.parameters(), lr=config['max_lr'])
            self.scheduler = None

        if tensorboard:
            self.tensorboard_writer = SummaryWriter(log_dir=os.path.join('./runs', config['name']))
        else:
            self.tensorboard_writer = None

    def load_checkpoint(self):
        name = self.config['name']
        path = f'./checkpoints/{name}.pth'

        if os.path.exists(path):
            checkpoint = torch.load(path)

            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

            if self.scheduler:
                self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

            self.content_loss_history = checkpoint['content_loss_history']
            self.style_loss_history = checkpoint['style_loss_history']
            self.total_variation_loss_history = checkpoint['total_variation_loss_history']
            self.loss_history = checkpoint['loss_history']
            self.lr_history = checkpoint['lr_history']

            del checkpoint
            torch.cuda.empty_cache()

    def train(self):
        start = time()

        for epoch in range(self.config['epochs']):
            self.epoch_loss_history = []

            for i, batch in enumerate(self.dataloader):
                self.do_training_step(i, batch)
                self.do_progress_bar_step(epoch, self.config['epochs'], i)

                if self.config['lr_scheduler'] == 'ReduceLROnPlateau':
                    self.scheduler.step(self.loss_history[-1])
                elif self.scheduler:
                    self.scheduler.step()

                if i % self.config['save_checkpoint_interval'] == 0:
                    self.save_checkpoint(f'./checkpoints/{self.config["name"]}.pth')

                if self.config['max_runtime'] != 0 and time() - start >= self.config['max_runtime']:
                    break

        torch.cuda.empty_cache()

    def do_training_step(self, i, batch):
        self.model.train()

        with torch.autograd.detect_anomaly():
            try:

                if self.data_transformer:
                    x, y = self.data_transformer(batch)
                else:
                    x, y = batch

                x = x.to(self.device)
                y = y.to(self.device)
                self.optimizer.zero_grad()

                preds = self.model(x)

                loss = self.criterion(preds, y)
                loss.backward()
                self.optimizer.step()

                self.lr_history.append(self.optimizer.param_groups[0]['lr'])
                self.epoch_loss_history.append(self.criterion.loss_val)

                self.content_loss_history.append(self.criterion.content_loss_val)
                self.style_loss_history.append(self.criterion.style_loss_val)
                self.total_variation_loss_history.append(self.criterion.total_variation_loss_val)
                self.loss_history.append(self.criterion.loss_val)

                if self.tensorboard_writer and i % self.config['save_checkpoint_interval'] == 0:
                    grid_y = torchvision.utils.make_grid(y)
                    grid_preds = torchvision.utils.make_grid(preds)

                    self.tensorboard_writer.add_image('Inputs', grid_y, 0)
                    self.tensorboard_writer.add_image('Predictions', grid_preds, 0)

                    # writer.add_graph(network, images)
                    self.tensorboard_writer.add_scalar(
                        'Content Loss',
                        self.content_loss_history[-1],
                        len(self.content_loss_history) - 1
                    )

                    self.tensorboard_writer.add_scalar(
                        'Style Loss',
                        self.style_loss_history[-1],
                        len(self.style_loss_history) - 1
                    )

                    self.tensorboard_writer.add_scalar(
                        'TV Loss',
                        self.total_variation_loss_history[-1],
                        len(self.total_variation_loss_history) - 1
                    )

                    self.tensorboard_writer.add_scalar(
                        'Total Loss',
                        self.loss_history[-1],
                        len(self.loss_history) - 1
                    )

                    self.tensorboard_writer.add_scalar(
                        'Learning Rate',
                        self.lr_history[-1],
                        len(self.lr_history) - 1
                    )

                    self.tensorboard_writer.close()
            except Exception:
                # If a training step fails (e.g. an anomaly detected by autograd),
                # fall back to the last saved checkpoint instead of crashing the run.
                self.load_checkpoint()

    def do_validation_step(self):
        self.model.eval()

    def do_progress_bar_step(self, epoch, epochs, i):
        avg_epoch_loss = sum(self.epoch_loss_history) / (i + 1)

        if len(self.loss_history) >= self.config['running_loss_range']:
            running_loss = sum(
                self.loss_history[-self.config['running_loss_range']:]
            ) / self.config['running_loss_range']
        else:
            running_loss = 0

        if len(self.loss_history) > 0:
            self.progress_bar.set_description(
                f'Name: {self.config["name"]}, ' +
                f'Loss Network: {self.config["loss_network"]}, ' +
                f'Epoch: {epoch + 1}/{epochs}, ' +
                f'Average Epoch Loss: {avg_epoch_loss:,.2f}, ' +
                f'Running Loss: {running_loss:,.2f}, ' +
                f'Loss: {self.loss_history[-1]:,.2f}, ' +
                f'Learning Rate: {self.lr_history[-1]:,.6f}'
            )
        else:
            self.progress_bar.set_description(
                f'Name: {self.config["name"]}, ' +
                f'Loss Network: {self.config["loss_network"]}, ' +
                f'Epoch: {epoch + 1}/{epochs}, ' +
                f'Average Epoch Loss: {0:,.2f}, ' +
                f'Running Loss: {0:,.2f}, ' +
                f'Loss: {0:,.2f}, ' +
                f'Learning Rate: {0:,.6f}'
            )

        self.progress_bar.update(1)
        self.progress_bar.refresh()

    def save_checkpoint(self, path):
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'content_loss_history': self.content_loss_history,
            'style_loss_history': self.style_loss_history,
            'total_variation_loss_history': self.total_variation_loss_history,
            'loss_history': self.loss_history,
            'lr_history': self.lr_history,
            'content_image_size': self.config['content_image_size'],
            'style_image_size': self.config['style_image_size'],
            'network': str(self.config['network']),
            'content_weight': self.config['content_weight'],
            'style_weight': self.config['style_weight'],
            'total_variation_weight': self.config['total_variation_weight'],
            'bottleneck_size': self.config['bottleneck_size'],
            'bottleneck_type': str(self.config['bottleneck_type']),
            'channel_multiplier': self.config['channel_multiplier'],
            'expansion_factor': self.config['expansion_factor'],
            'intermediate_activation_fn': self.config['intermediate_activation_fn'],
            'final_activation_fn': self.config['final_activation_fn'],
            'meta_data': self.meta_data
        }, path)
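
The loop above recovers from a failed iteration by calling self.load_checkpoint(), which is not shown in this excerpt. Below is a minimal sketch of a matching loader, assuming the same keys written by save_checkpoint; the standalone function form and its arguments are illustrative, not the original method:

import torch

def load_checkpoint(path, model, optimizer, scheduler=None, map_location='cpu'):
    # Restore the state dicts written by save_checkpoint; the loss/lr histories
    # and config entries are plain Python objects and come back unchanged.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None and checkpoint.get('scheduler_state_dict') is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return checkpoint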
Example #28
0
def main(args):

    args = parse_args()
    tag = args.tag
    device = torch.device('cuda:0')

    no_epochs = args.epochs
    batch_size = args.batch

    linear_hidden = args.linear
    conv_hidden = args.conv

    # Get train/test paths -> later on implement cross-validation
    steps = get_paths(as_tuples=True, shuffle=True, tag=tag)
    split_idx = int(len(steps) * .8)  # 80/20 train/test split
    steps_train, steps_test = steps[:split_idx], steps[split_idx:]

    transform = transforms.Compose(
        [DepthSegmentationPreprocess(no_data_points=1),
         ToSupervised()])

    dataset_train = SimpleDataset(ids=steps_train,
                                  batch_size=batch_size,
                                  transform=transform,
                                  **SENSORS)
    dataset_test = SimpleDataset(ids=steps_test,
                                 batch_size=batch_size,
                                 transform=transform,
                                 **SENSORS)

    dataloader_params = {
        'batch_size': batch_size,
        'shuffle': True,
        'num_workers': 8
    }  #we've already shuffled paths

    dataset_train = DataLoader(dataset_train, **dataloader_params)
    dataset_test = DataLoader(dataset_test, **dataloader_params)

    batch = next(iter(dataset_test))
    action_shape = batch['action'][0].shape
    img_shape = batch['img'][0].shape
    #Nets
    net = DDPGActor(img_shape=img_shape,
                    numeric_shape=[len(NUMERIC_FEATURES)],
                    output_shape=[2],
                    linear_hidden=linear_hidden,
                    conv_filters=conv_hidden)
    # net = DDPGCritic(actor_out_shape=action_shape, img_shape=img_shape, numeric_shape=[len(NUMERIC_FEATURES)],
    #                         linear_hidden=linear_hidden, conv_filters=conv_filters)

    print(len(steps))
    print(net)
    print(get_n_params(net))
    # save path
    net_path = f'../data/models/imitation/{DATE_TIME}/{net.name}'
    os.makedirs(net_path, exist_ok=True)
    optim_steps = args.optim_steps
    logging_idx = int(len(dataset_train.dataset) / (batch_size * optim_steps))

    writer_train = SummaryWriter(f'{net_path}/train',
                                 max_queue=30,
                                 flush_secs=5)
    writer_test = SummaryWriter(f'{net_path}/test', max_queue=1, flush_secs=5)

    #Optimizers
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=0.001,
                                 weight_decay=0.0005)

    if args.scheduler == 'cos':
        scheduler = CosineAnnealingWarmRestarts(optimizer,
                                                T_0=optim_steps,
                                                T_mult=2)
    elif args.scheduler == 'one_cycle':
        scheduler = OneCycleLR(optimizer,
                               max_lr=0.001,
                               epochs=no_epochs,
                               steps_per_epoch=optim_steps)
    else:
        raise ValueError(f'Unknown scheduler: {args.scheduler}')

    #Loss function
    loss_function = torch.nn.MSELoss(reduction='sum')
    test_loss_function = torch.nn.MSELoss(reduction='sum')

    best_train_loss = 1e10
    best_test_loss = 1e10

    for epoch_idx in range(no_epochs):
        train_loss = .0
        running_loss = .0
        # critic_running_loss = .0
        avg_max_grad = 0.
        avg_avg_grad = 0.
        for idx, batch in enumerate(iter(dataset_train)):
            global_step = int((len(dataset_train.dataset) / batch_size *
                               epoch_idx) + idx)
            batch = unpack_batch(batch=batch, device=device)
            loss, grad = train(input=batch,
                               label=batch['action'],
                               net=net,
                               optimizer=optimizer,
                               loss_fn=loss_function)
            # loss, grad = train(input=batch, label=batch['q'], net=net, optimizer=optimizer, loss_fn=loss_function)

            avg_max_grad += max([element.max() for element in grad])
            avg_avg_grad += sum([element.mean()
                                 for element in grad]) / len(grad)

            running_loss += loss
            train_loss += loss

            writer_train.add_scalar(tag=f'{net.name}/running_loss',
                                    scalar_value=loss / batch_size,
                                    global_step=global_step)
            writer_train.add_scalar(tag=f'{net.name}/max_grad',
                                    scalar_value=avg_max_grad,
                                    global_step=global_step)
            writer_train.add_scalar(tag=f'{net.name}/mean_grad',
                                    scalar_value=avg_avg_grad,
                                    global_step=global_step)

            if idx % logging_idx == logging_idx - 1:
                print(
                    f'Actor Epoch: {epoch_idx + 1}, Batch: {idx+1}, Loss: {running_loss/logging_idx}, Lr: {scheduler.get_last_lr()[0]}'
                )
                if (running_loss / logging_idx) < best_train_loss:
                    best_train_loss = running_loss / logging_idx
                    torch.save(net.state_dict(), f'{net_path}/train/train.pt')

                writer_train.add_scalar(
                    tag=f'{net.name}/lr',
                    scalar_value=scheduler.get_last_lr()[0],
                    global_step=global_step)
                running_loss = 0.0
                avg_max_grad = 0.
                avg_avg_grad = 0.
                scheduler.step()

        print(
            f'{net.name} best train loss for epoch {epoch_idx+1} - {best_train_loss}'
        )
        writer_train.add_scalar(tag=f'{net.name}/global_loss',
                                scalar_value=train_loss /
                                len(dataset_train.dataset),
                                global_step=(epoch_idx + 1))
        test_loss = .0
        with torch.no_grad():
            for idx, batch in enumerate(iter(dataset_test)):
                batch = unpack_batch(batch=batch, device=device)
                pred = net(**batch)
                loss = test_loss_function(pred, batch['action'])
                # loss = test_loss_function(pred.view(-1), batch['q'])

                test_loss += loss

        if (test_loss / len(dataset_test)) < best_test_loss:
            best_test_loss = (test_loss / len(dataset_test))

        torch.save(net.state_dict(), f'{net_path}/test/test_{epoch_idx+1}.pt')

        print(f'{net.name} test loss {(test_loss/len(dataset_test)):.3f}')
        print(f'{net.name} best test loss {best_test_loss:.3f}')
        writer_test.add_scalar(tag=f'{net.name}/global_loss',
                               scalar_value=(test_loss /
                                             len(dataset_test.dataset)),
                               global_step=(epoch_idx + 1))

    torch.save(optimizer.state_dict(),
               f=f'{net_path}/{optimizer.__class__.__name__}.pt')
    torch.save(scheduler.state_dict(),
               f=f'{net_path}/{scheduler.__class__.__name__}.pt')
    json.dump(vars(args),
              fp=open(f'{net_path}/args.json', 'w'),
              sort_keys=True,
              indent=4)

    writer_train.flush()
    writer_test.flush()
    writer_train.close()
    writer_test.close()

    batch = next(iter(dataset_test))
    batch = unpack_batch(batch=batch, device=device)
    y = net(**batch)
    g = make_dot(y, params=dict(net.named_parameters()))
    g.save(filename=f'{DATE_TIME}_{net.name}.dot', directory=net_path)
    check_call([
        'dot', '-Tpng', '-Gdpi=200', f'{net_path}/{DATE_TIME}_{net.name}.dot',
        '-o', f'{net_path}/{DATE_TIME}_{net.name}.png'
    ])
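
With args.scheduler == 'cos', the scheduler above is stepped once per logging interval, restarting the cosine cycle every T_0=optim_steps steps and doubling the cycle length after each restart (T_mult=2). A small self-contained sketch, with a dummy parameter standing in for the network, that prints the resulting learning-rate curve:

import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

dummy = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(dummy, lr=1e-3)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

for step in range(70):
    optimizer.step()
    scheduler.step()
    # The LR returns to its base value after 10 steps, then 20 more, then 40 more.
    print(step, scheduler.get_last_lr()[0])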
Example #29
0
def train(opt):
    params = Params(f'configs/{opt.project}.yml')

    if params.num_gpus == 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    else:
        torch.manual_seed(42)

    opt.saved_path = opt.saved_path + f'/{params.project_name}/'
    opt.log_path = opt.log_path + f'/{params.project_name}/tensorboard/'
    os.makedirs(opt.log_path, exist_ok=True)
    os.makedirs(opt.saved_path, exist_ok=True)

    training_params = {'batch_size': opt.batch_size,
                       'shuffle': True,
                       'drop_last': True,
                       'collate_fn': collater,
                       'num_workers': opt.num_workers}

    val_params = {'batch_size': opt.batch_size,
                  'shuffle': False,
                  'drop_last': True,
                  'collate_fn': collater,
                  'num_workers': opt.num_workers}

    input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
    # input_sizes = [640, 1024, 1535]
    training_set = CocoDataset(root_dir=os.path.join(opt.data_path, params.project_name), set=params.train_set,
                               transform=transforms.Compose([Normalizer(mean=params.mean, std=params.std),
                                                             RandomFlip(flip_ratio=0.5),
                                                             CutOut(n_holes=8, cutout_shape=(32,32)),
                                                            #  MixUp(p=0.5, lambd=0.5),
                                                             Augmenter(),
                                                             Resizer(input_sizes[opt.compound_coef])]))
    training_generator = DataLoader(training_set, **training_params)

    val_set = CocoDataset(root_dir=os.path.join(opt.data_path, params.project_name), set=params.val_set,
                          transform=transforms.Compose([Normalizer(mean=params.mean, std=params.std),
                                                        Resizer(input_sizes[opt.compound_coef])]))
    val_generator = DataLoader(val_set, **val_params)

    model = EfficientDetBackbone(num_classes=len(params.obj_list), compound_coef=opt.compound_coef,
                                 ratios=eval(params.anchors_ratios), scales=eval(params.anchors_scales))

    if opt.optim == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), opt.lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(), opt.lr, momentum=0.9, nesterov=True)

    if FP16:
        model, optimizer = amp.initialize(model.cuda(), optimizer, opt_level='O1', verbosity=0)
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
    # scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=params["lr_min"], last_epoch=-1)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-7, last_epoch=-1)
    
    # load last weights
    if opt.load_weights is not None:
        if opt.load_weights.endswith('.pth'):
            weights_path = opt.load_weights
        else:
            weights_path = get_last_weights(opt.saved_path)
        try:
            last_step = int(os.path.basename(weights_path).split('_')[-1].split('.')[0])
        except (ValueError, IndexError):
            last_step = 0

        try:
            ret = model.load_state_dict(torch.load(weights_path), strict=False)
        except RuntimeError as e:
            print(f'[Warning] Ignoring {e}')
            print(
                '[Warning] Don\'t panic if you see this; it usually means the pretrained weights were trained with a different number of classes. The rest of the weights should have been loaded already.')

        print(f'[Info] loaded weights: {os.path.basename(weights_path)}, resuming checkpoint from step: {last_step}')
    else:
        last_step = 0
        print('[Info] initializing weights...')
        init_weights(model)

    # freeze backbone if train head_only
    if opt.head_only:
        def freeze_backbone(m):
            classname = m.__class__.__name__
            for ntl in ['EfficientNet', 'BiFPN']:
                if ntl in classname:
                    for param in m.parameters():
                        param.requires_grad = False

        model.apply(freeze_backbone)
        print('[Info] froze backbone')

    # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
    # Apply sync_bn when using multiple GPUs and the batch size per GPU is lower than 4;
    # this is useful when GPU memory is limited. With such small per-GPU batches, plain
    # BatchNorm makes training unstable or slow to converge. sync_bn solves this by
    # normalizing one combined mini-batch across all GPUs and sending the result back,
    # at the cost of slightly slower training.
    if params.num_gpus > 1 and opt.batch_size // params.num_gpus < 4:
        model.apply(replace_w_sync_bn)
        use_sync_bn = True
    else:
        use_sync_bn = False

    writer = SummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

    # wrap the model with the loss function to reduce memory usage on gpu0 and speed up training
    model = ModelWithLoss(model, debug=opt.debug)

    if params.num_gpus > 0:
        model = model.cuda()
        if params.num_gpus > 1:
            model = CustomDataParallel(model, params.num_gpus)
            if use_sync_bn:
                patch_replication_callback(model)

    epoch = 0
    best_loss = 1e5
    best_epoch = 0
    step = max(0, last_step)
    model.train()

    num_iter_per_epoch = len(training_generator)

    try:
        for epoch in range(opt.num_epochs):
            last_epoch = step // num_iter_per_epoch
            if epoch < last_epoch:
                continue

            epoch_loss = []
            progress_bar = tqdm(training_generator)
            for iter, data in enumerate(progress_bar):
                if iter < step - last_epoch * num_iter_per_epoch:
                    progress_bar.update()
                    continue
                try:
                    imgs = data['img']
                    annot = data['annot']

                    if params.num_gpus == 1:
                        # if only one gpu, just send it to cuda:0
                        # elif multiple gpus, send it to multiple gpus in CustomDataParallel, not here
                        imgs = imgs.cuda()
                        annot = annot.cuda()


                    optimizer.zero_grad()
                    cls_loss, reg_loss = model(imgs, annot, obj_list=params.obj_list)
                    cls_loss = cls_loss.mean()
                    reg_loss = reg_loss.mean()

                    loss = cls_loss + reg_loss
                    if loss == 0 or not torch.isfinite(loss):
                        continue

                    # calculate gradient
                    if FP16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:        
                        loss.backward()

                    # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
                    optimizer.step()

                    epoch_loss.append(float(loss))

                    progress_bar.set_description(
                        'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Total loss: {:.5f}'.format(
                            step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch, cls_loss.item(),
                            reg_loss.item(), loss.item()))
                    writer.add_scalars('Loss', {'train': loss}, step)
                    writer.add_scalars('Regression_loss', {'train': reg_loss}, step)
                    writer.add_scalars('Classification_loss', {'train': cls_loss}, step)

                    # log learning_rate
                    current_lr = optimizer.param_groups[0]['lr']
                    writer.add_scalar('learning_rate', current_lr, step)

                    step += 1

                    if step % opt.save_interval == 0 and step > 0:
                        save_checkpoint(model, f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth')
                        print('checkpoint...')

                except Exception as e:
                    print('[Error]', traceback.format_exc())
                    print(e)
                    continue
            # CosineAnnealingWarmRestarts is stepped once per epoch; only the
            # ReduceLROnPlateau variant commented out above takes a metric.
            scheduler.step()

            if epoch % opt.val_interval == 0:
                model.eval()
                loss_regression_ls = []
                loss_classification_ls = []
                for iter, data in enumerate(val_generator):
                    with torch.no_grad():
                        imgs = data['img']
                        annot = data['annot']

                        if params.num_gpus == 1:
                            imgs = imgs.cuda()
                            annot = annot.cuda()

                        cls_loss, reg_loss = model(imgs, annot, obj_list=params.obj_list)
                        cls_loss = cls_loss.mean()
                        reg_loss = reg_loss.mean()

                        loss = cls_loss + reg_loss
                        if loss == 0 or not torch.isfinite(loss):
                            continue

                        loss_classification_ls.append(cls_loss.item())
                        loss_regression_ls.append(reg_loss.item())

                cls_loss = np.mean(loss_classification_ls)
                reg_loss = np.mean(loss_regression_ls)
                loss = cls_loss + reg_loss

                print(
                    'Val. Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'.format(
                        epoch, opt.num_epochs, cls_loss, reg_loss, loss))
                writer.add_scalars('Loss', {'val': loss}, step)
                writer.add_scalars('Regression_loss', {'val': reg_loss}, step)
                writer.add_scalars('Classification_loss', {'val': cls_loss}, step)

                if loss + opt.es_min_delta < best_loss:
                    best_loss = loss
                    best_epoch = epoch

                    save_checkpoint(model, f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth')

                model.train()

                # Early stopping
                if epoch - best_epoch > opt.es_patience > 0:
                    print('[Info] Stop training at epoch {}. The lowest loss achieved is {}'.format(epoch, best_loss))
                    break
    except KeyboardInterrupt:
        save_checkpoint(model, f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth')
        writer.close()
    writer.close()
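
Example #29 relies on NVIDIA apex (amp.initialize / amp.scale_loss) for FP16 training. Below is a minimal sketch of the equivalent step with PyTorch's native torch.cuda.amp API; the tiny linear model and random batch are placeholders, not part of the example:

import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Linear(16, 4).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))

x = torch.randn(8, 16, device=device)
y = torch.randint(0, 4, (8,), device=device)

optimizer.zero_grad()
# Run the forward pass in mixed precision, then scale the loss before backward.
with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
    loss = nn.functional.cross_entropy(model(x), y)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()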
Example #30
0
def run_training(data_type="screw",
                 model_dir="models",
                 epochs=256,
                 pretrained=True,
                 test_epochs=10,
                 freeze_resnet=20,
                 learning_rate=0.03,
                 optim_name="SGD",
                 batch_size=64,
                 head_layer=8):
    torch.multiprocessing.freeze_support()
    # TODO: use script params for hyperparameters
    # Temperature hyperparameter, currently not used
    temperature = 0.2
    device = "cuda"

    weight_decay = 0.00003
    momentum = 0.9
    # TODO: use an f-string for the date as well
    model_name = f"model-{data_type}" + '-{date:%Y-%m-%d_%H_%M_%S}'.format(
        date=datetime.datetime.now())

    #augmentation:
    size = 256
    min_scale = 0.5

    # create Training Dataset and Dataloader
    after_cutpaste_transform = transforms.Compose([])
    after_cutpaste_transform.transforms.append(transforms.ToTensor())
    after_cutpaste_transform.transforms.append(
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]))

    train_transform = transforms.Compose([])
    # train_transform.transforms.append(transforms.RandomResizedCrop(size, scale=(min_scale,1)))
    # train_transform.transforms.append(transforms.GaussianBlur(int(size/10), sigma=(0.1,2.0)))
    train_transform.transforms.append(transforms.Resize((256, 256)))
    train_transform.transforms.append(
        CutPaste(transform=after_cutpaste_transform))
    # train_transform.transforms.append(transforms.ToTensor())

    train_data = MVTecAT("Data",
                         data_type,
                         transform=train_transform,
                         size=int(size * (1 / min_scale)))
    dataloader = DataLoader(train_data,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=8,
                            collate_fn=cut_paste_collate_fn,
                            persistent_workers=True,
                            pin_memory=True,
                            prefetch_factor=5)

    # Writer will output to ./runs/ directory by default
    writer = SummaryWriter(Path("logdirs") / model_name)

    # create Model:
    head_layers = [512] * head_layer + [128]
    print(head_layers)
    model = ProjectionNet(pretrained=pretrained, head_layers=head_layers)
    model.to(device)

    if freeze_resnet > 0:
        model.freeze_resnet()

    loss_fn = torch.nn.CrossEntropyLoss()
    if optim_name == "sgd":
        optimizer = optim.SGD(model.parameters(),
                              lr=learning_rate,
                              momentum=momentum,
                              weight_decay=weight_decay)
        scheduler = CosineAnnealingWarmRestarts(optimizer, epochs)
        #scheduler = None
    elif optim_name.lower() == "adam":
        optimizer = optim.Adam(model.parameters(),
                               lr=learning_rate,
                               weight_decay=weight_decay)
        scheduler = None
    else:
        print(f"ERROR unkown optimizer: {optim_name}")

    step = 0
    num_batches = len(dataloader)

    def get_data_inf():
        while True:
            for out in enumerate(dataloader):
                yield out

    dataloader_inf = get_data_inf()
    # From the paper: "Note that, unlike conventional definition for an epoch,
    #                  we define 256 parameter update steps as one epoch."
    for step in tqdm(range(epochs * 256)):
        epoch = int(step / 256)
        if epoch == freeze_resnet:
            model.unfreeze()

        batch_embeds = []
        batch_idx, data = next(dataloader_inf)
        x1, x2 = data
        x1 = x1.to(device)
        x2 = x2.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        xc = torch.cat((x1, x2), axis=0)
        embeds, logits = model(xc)

        #         embeds = F.normalize(embeds, p=2, dim=1)
        #         embeds1, embeds2 = torch.split(embeds,x1.size(0),dim=0)
        #         ip = torch.matmul(embeds1, embeds2.T)
        #         ip = ip / temperature

        #         y = torch.arange(0,x1.size(0), device=device)
        #         loss = loss_fn(ip, torch.arange(0,x1.size(0), device=device))

        y = torch.tensor([0, 1], device=device)
        y = y.repeat_interleave(x1.size(0))
        loss = loss_fn(logits, y)

        # backpropagate and update the weights
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step(epoch + batch_idx / num_batches)

        writer.add_scalar('loss', loss.item(), step)

        #         predicted = torch.argmax(ip,axis=0)
        predicted = torch.argmax(logits, axis=1)
        #         print(logits)
        #         print(predicted)
        #         print(y)
        accuracy = torch.true_divide(torch.sum(predicted == y),
                                     predicted.size(0))
        writer.add_scalar('acc', accuracy, step)
        if scheduler is not None:
            writer.add_scalar('lr', scheduler.get_last_lr()[0], step)

        # save embed for validation:
        if test_epochs > 0 and epoch % test_epochs == 0:
            batch_embeds.append(embeds.cpu().detach())

        writer.add_scalar('epoch', epoch, step)

        # run tests
        if test_epochs > 0 and epoch % test_epochs == 0:
            # run auc calculation
            # TODO: create the dataset only once.
            # TODO: train the predictor here or in the model class itself; it should not be in the eval part.
            # TODO: we might not want to use the training data because of dropout etc., but it should give an indication of the model's performance.
            # batch_embeds = torch.cat(batch_embeds)
            # print(batch_embeds.shape)
            model.eval()
            roc_auc = eval_model(model_name,
                                 data_type,
                                 device=device,
                                 save_plots=False,
                                 size=size,
                                 show_training_data=False,
                                 model=model)
            #train_embed=batch_embeds)
            model.train()
            writer.add_scalar('eval_auc', roc_auc, step)

    # model_dir is a plain string; build a Path and make sure the directory exists.
    model_dir = Path(model_dir)
    model_dir.mkdir(exist_ok=True)
    torch.save(model.state_dict(), model_dir / f"{model_name}.tch")
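
The SGD branch above advances the scheduler with scheduler.step(epoch + batch_idx / num_batches), the fractional-epoch form that lets CosineAnnealingWarmRestarts decay the learning rate within an epoch rather than only at epoch boundaries. A short standalone sketch of that call pattern; the dummy parameter and the epoch/batch counts are illustrative:

import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

dummy = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(dummy, lr=0.03, momentum=0.9)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=4)  # restart every 4 "epochs"

num_batches = 8
for epoch in range(8):
    for batch_idx in range(num_batches):
        optimizer.step()
        # Passing a fractional epoch moves along the cosine curve between epoch boundaries.
        scheduler.step(epoch + batch_idx / num_batches)
    print(epoch, scheduler.get_last_lr()[0])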