Example #1
def main(args):
    if args.apex:
        if sys.version_info < (3, 0):
            raise RuntimeError(
                "Apex currently only supports Python 3. Aborting.")
        if amp is None:
            raise RuntimeError(
                "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                "to enable mixed-precision training.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    vis = utils.Visualize(args)

    utils.init_distributed_mode(args)
    print(args)
    print("torch version: ", torch.__version__)
    print("torchvision version: ", torchvision.__version__)

    device = torch.device(args.device)

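    # cudnn benchmark mode autotunes conv kernels; it pays off here because
    # the clip shapes stay fixed across iterations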
    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    traindir = os.path.join(
        args.data_path, 'train_256' if not args.fast_test else 'val_256_bob')
    valdir = os.path.join(args.data_path, 'val_256_bob')
    normalize = T.Normalize(mean=[0.43216, 0.394666, 0.37645],
                            std=[0.22803, 0.22145, 0.216989])

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)

    frame_transform_train = utils.make_frame_transform(args.frame_transforms)

    transform_train = torchvision.transforms.Compose([
        #         torchvision.transforms.RandomGrayscale(p=1),
        frame_transform_train,
        T.ToFloatTensorInZeroOne(),
        T.Resize((256, 256)),
        # T.Resize((128, 171)),
        # T.RandomHorizontalFlip(),
        # T.GaussianBlurTransform(),
        normalize,
        # T.RandomCrop((112, 112))
    ])

    def make_dataset(is_train):
        _transform = transform_train if is_train else transform_test

        if 'kinetics' in args.data_path.lower():
            return Kinetics400(traindir if is_train else valdir,
                               frames_per_clip=args.clip_len,
                               step_between_clips=1,
                               transform=_transform,  # was transform_train: the test split got train augs
                               extensions=('mp4',),   # one-element tuple; ('mp4') is just the string 'mp4'
                               frame_rate=args.frame_skip)
        else:
            return VideoList(
                args,
                is_train,
                frame_gap=args.frame_skip,
                transform=_transform,
                # frame_transform=_frame_transform
            )

    if args.cache_dataset and os.path.exists(cache_path):
        print("Loading dataset_train from {}".format(cache_path))
        dataset, _ = torch.load(cache_path)
        dataset.transform = transform_train
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache "
                  "on a single-gpu first, as it will be faster")
        dataset = make_dataset(is_train=True)

        if args.cache_dataset:
            print("Saving dataset_train to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)

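    # re-slice the loaded videos into clips of clip_len frames at 15 fps
    # (only datasets backed by a torchvision VideoClips index support this)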
    if hasattr(dataset, 'video_clips'):
        dataset.video_clips.compute_clips(args.clip_len, 1, frame_rate=15)

    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)

    transform_test = torchvision.transforms.Compose([
        T.ToFloatTensorInZeroOne(),
        # T.Resize((128, 171)),
        # normalize,
        # T.CenterCrop((112, 112))
        T.Resize((256, 256)),
        normalize
    ])

    if args.cache_dataset and os.path.exists(cache_path):
        print("Loading dataset_test from {}".format(cache_path))
        dataset_test, _ = torch.load(cache_path)
        dataset_test.transform = transform_test
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache "
                  "on a single-gpu first, as it will be faster")
        # dataset_test = Kinetics400(
        #     valdir,
        #     frames_per_clip=args.clip_len,
        #     step_between_clips=1,
        #     transform=transform_test,
        #     extensions=('mp4')
        # )
        dataset_test = make_dataset(is_train=False)

        if args.cache_dataset:
            print("Saving dataset_test to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    if hasattr(dataset_test, 'video_clips'):
        dataset_test.video_clips.compute_clips(args.clip_len, 1, frame_rate=15)

    def make_data_sampler(is_train, dataset):
        if hasattr(dataset, 'video_clips'):
            _sampler = RandomClipSampler if is_train else UniformClipSampler
            return _sampler(dataset.video_clips, args.clips_per_video)
        else:
            return torch.utils.data.sampler.RandomSampler(
                dataset) if is_train else None

    print("Creating data loaders")
    train_sampler = make_data_sampler(True, dataset)
    test_sampler = make_data_sampler(False, dataset_test)
    # train_sampler = train_sampler(dataset.video_clips, args.clips_per_video)
    # test_sampler = test_sampler(dataset_test.video_clips, args.clips_per_video)

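    # note: wrapping an existing sampler assumes the DistributedSampler from the
    # torchvision video references, which shards another sampler; the stock
    # torch.utils.data.DistributedSampler only accepts a dataset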
    if args.distributed:
        train_sampler = DistributedSampler(train_sampler)
        test_sampler = DistributedSampler(test_sampler)

    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              sampler=train_sampler,
                                              num_workers=args.workers,
                                              pin_memory=True,
                                              collate_fn=collate_fn)

    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=args.batch_size,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   collate_fn=collate_fn)

    print("Creating model")
    import resnet
    import timecycle as tc
    # model = resnet.__dict__[args.model](pretrained=args.pretrained)
    model = tc.TimeCycle(args)

    # utils.compute_RF_numerical(model.resnet, torch.ones(1, 3, 1, 112, 112).numpy())
    # import pdb; pdb.set_trace()
    # print(utils.compute_RF_numerical(model,img_np))

    model.to(device)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    # note: the world-size-scaled lr is currently unused; Adam below uses a fixed 3e-4
    lr = args.lr * args.world_size
    # optimizer = torch.optim.SGD(
    #     model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    if args.apex:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.apex_opt_level)

    # convert the scheduler to step per iteration rather than per epoch,
    # so that warmup can span multiple epochs
    warmup_iters = args.lr_warmup_epochs * len(data_loader)
    lr_milestones = [len(data_loader) * m for m in args.lr_milestones]
    lr_scheduler = WarmupMultiStepLR(optimizer,
                                     milestones=lr_milestones,
                                     gamma=args.lr_gamma,
                                     warmup_iters=warmup_iters,
                                     warmup_factor=1e-5)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.data_parallel:
        model = torch.nn.parallel.DataParallel(model)
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model,
                        criterion,
                        optimizer,
                        lr_scheduler,
                        data_loader,
                        device,
                        epoch,
                        args.print_freq,
                        args.apex,
                        vis=vis)
        # evaluate(model, criterion, data_loader_test, device=device)
        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir, 'checkpoint.pth'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
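The helper _get_cache_path is called above but never shown. A minimal sketch, assuming the convention of the torchvision reference scripts (a hashed per-directory cache file under ~/.torch/); the exact location is an assumption:

import hashlib
import os

def _get_cache_path(filepath):
    # hash the dataset directory so each split gets a stable, unique cache file
    h = hashlib.sha1(filepath.encode()).hexdigest()
    cache_path = os.path.join('~', '.torch', 'vision', 'datasets',
                              'kinetics', h[:10] + '.pt')
    return os.path.expanduser(cache_path)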
Example #2
def train_model_on_dataset(rank, cfg):
    dist_rank = rank
    # print(dist_rank)
    dist.init_process_group(backend="nccl", rank=dist_rank,
                            world_size=cfg.num_gpu,
                            init_method="env://")
    torch.cuda.set_device(rank)
    cudnn.benchmark = True
    dataset = CityFlowNLDataset(cfg, build_transforms(cfg))

    model = MyModel(cfg, len(dataset.nl), dataset.nl.word_to_idx['<PAD>'],
                    norm_layer=nn.SyncBatchNorm,
                    num_colors=len(CityFlowNLDataset.colors),
                    num_types=len(CityFlowNLDataset.vehicle_type) - 2).cuda()
    model = DistributedDataParallel(model, device_ids=[rank],
                                    output_device=rank,
                                    broadcast_buffers=cfg.num_gpu > 1,
                                    find_unused_parameters=False)
    optimizer = torch.optim.Adam(params=model.parameters(),
                                 lr=cfg.TRAIN.LR.BASE_LR,
                                 weight_decay=3e-5)
    lr_scheduler = WarmupMultiStepLR(optimizer,
                                     milestones=cfg.TRAIN.STEPS,
                                     # decay factor; stored under the LR.WEIGHT_DECAY config key
                                     gamma=cfg.TRAIN.LR.WEIGHT_DECAY,
                                     warmup_factor=cfg.TRAIN.WARMUP_FACTOR,
                                     warmup_iters=cfg.TRAIN.WARMUP_EPOCH)
    color_loss = LabelSmoothingLoss(len(dataset.colors), 0.1)
    vehicle_loss = LabelSmoothingLoss(len(dataset.vehicle_type) - 2, 0.1)
    if cfg.resume_epoch > 0:
        model.load_state_dict(torch.load(f'save/{cfg.resume_epoch}.pth'))
        optimizer.load_state_dict(torch.load(f'save/{cfg.resume_epoch}_optim.pth'))
        lr_scheduler.last_epoch = cfg.resume_epoch
        lr_scheduler.step()
        if rank == 0:
            print(f'resume from {cfg.resume_epoch} pth file, starting {cfg.resume_epoch+1} epoch')
        cfg.resume_epoch += 1

    # loader = DataLoader(dataset, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=True, num_workers=cfg.TRAIN.NUM_WORKERS)
    train_sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset,
                        batch_size=cfg.TRAIN.BATCH_SIZE // cfg.num_gpu,
                        num_workers=cfg.TRAIN.NUM_WORKERS // cfg.num_gpu,
                        sampler=train_sampler, pin_memory=True)
    for epoch in range(cfg.resume_epoch, cfg.TRAIN.EPOCH):
        losses = 0.
        losses_color = 0.
        losses_types = 0.
        losses_nl_color = 0.
        losses_nl_types = 0.
        precs = 0.
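        # DistributedSampler derives its shuffling seed from the epoch, so each
        # epoch sees a different shard ordering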
        train_sampler.set_epoch(epoch)
        for idx, (nl, frame, label, act_map, color_label, type_label,
                  nl_color_label, nl_type_label) in enumerate(loader):
            # print(nl.shape)
            # print(global_img.shape)
            # print(local_img.shape)
            nl = nl.cuda(non_blocking=True)
            label = label.cuda(non_blocking=True)
            act_map = act_map.cuda(non_blocking=True)
            # global_img, local_img = global_img.cuda(), local_img.cuda()
            nl = nl.transpose(1, 0)  # (batch, seq) -> (seq, batch), the sequence-first layout the text encoder expects
            frame = frame.cuda(non_blocking=True)
            color_label = color_label.cuda(non_blocking=True)
            type_label = type_label.cuda(non_blocking=True)
            nl_color_label = nl_color_label.cuda(non_blocking=True)
            nl_type_label = nl_type_label.cuda(non_blocking=True)
            output, color, types, nl_color, nl_types = model(nl, frame, act_map)
            
            # loss = sampling_loss(output, label, ratio=5)
            # loss = F.binary_cross_entropy_with_logits(output, label)
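            # normalize the focal loss by the mean number of positive labels per
            # GPU so the loss scale does not depend on how many positives a
            # particular batch happens to contain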
            total_num_pos = reduce_sum(label.new_tensor([label.sum()])).item()
            num_pos_avg_per_gpu = max(total_num_pos / float(cfg.num_gpu), 1.0)

            loss = sigmoid_focal_loss(output, label, reduction='sum') / num_pos_avg_per_gpu
            loss_color = color_loss(color, color_label) * cfg.TRAIN.ALPHA_COLOR
            loss_type = vehicle_loss(types, type_label) * cfg.TRAIN.ALPHA_TYPE
            loss_nl_color = color_loss(nl_color, nl_color_label) * cfg.TRAIN.ALPHA_NL_COLOR
            loss_nl_type = vehicle_loss(nl_types, nl_type_label) * cfg.TRAIN.ALPHA_NL_TYPE
            loss_total = loss + loss_color + loss_type + loss_nl_color + loss_nl_type
            optimizer.zero_grad()
            loss_total.backward()
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            
            losses += loss.item()
            losses_color += loss_color.item()
            losses_types += loss_type.item()
            losses_nl_color += loss_nl_color.item()
            losses_nl_types += loss_nl_type.item()
            # precs += recall.item()
            
            if rank == 0 and idx % cfg.TRAIN.PRINT_FREQ == 0:
                pred = (output.sigmoid() > 0.5)
                # correct predictions masked to positives give the true-positive count
                pred = (pred == label)
                recall = (pred * label).sum() / label.sum()
                ca = (color.argmax(dim=1) == color_label)
                ca = ca.sum().item() / ca.numel()
                ta = (types.argmax(dim=1) == type_label)
                ta = ta.sum().item() / ta.numel()
                # accu = pred.sum().item() / pred.numel()
                lr = optimizer.param_groups[0]['lr']
                print(f'epoch: {epoch},',
                      f'lr: {lr}, step: {idx}/{len(loader)},',
                      f'loss: {losses / (idx + 1):.4f},',
                      f'loss color: {losses_color / (idx + 1):.4f},',
                      f'loss type: {losses_types / (idx + 1):.4f},',
                      f'loss nl color: {losses_nl_color / (idx + 1):.4f},',
                      f'loss nl type: {losses_nl_types / (idx + 1):.4f},',
                      f'recall: {recall.item():.4f}, c_accu: {ca:.4f}, t_accu: {ta:.4f}')
        lr_scheduler.step()
        if rank == 0:
            if not os.path.exists('save'):
                os.mkdir('save')
            torch.save(model.state_dict(), f'save/{epoch}.pth')
            torch.save(optimizer.state_dict(), f'save/{epoch}_optim.pth')
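A minimal launcher sketch for the function above, assuming single-node training. Since init_process_group uses init_method="env://", MASTER_ADDR and MASTER_PORT must be set before the workers start; the port choice and the launch wrapper name are assumptions, and cfg is whatever config object the project already builds:

import os
import torch.multiprocessing as mp

def launch(cfg):
    # env:// rendezvous reads these variables in every worker
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')  # any free port works
    # one process per GPU; spawn passes the process index as `rank`
    mp.spawn(train_model_on_dataset, args=(cfg,), nprocs=cfg.num_gpu)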
Example #3
def main(args):
    if args.apex and amp is None:
        raise RuntimeError(
            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
            "to enable mixed-precision training.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)
    print("torch version: ", torch.__version__)
    print("torchvision version: ", torchvision.__version__)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    traindir = os.path.join(args.data_path, args.train_dir)
    valdir = os.path.join(args.data_path, args.val_dir)
    normalize = T.Normalize(mean=[0.43216, 0.394666, 0.37645],
                            std=[0.22803, 0.22145, 0.216989])

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    transform_train = torchvision.transforms.Compose([
        T.ToFloatTensorInZeroOne(),
        T.Resize((128, 171)),
        T.RandomHorizontalFlip(), normalize,
        T.RandomCrop((112, 112))
    ])

    if args.cache_dataset and os.path.exists(cache_path):
        print("Loading dataset_train from {}".format(cache_path))
        dataset, _ = torch.load(cache_path)
        dataset.transform = transform_train
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache "
                  "on a single-gpu first, as it will be faster")
        dataset = torchvision.datasets.Kinetics400(
            traindir,
            frames_per_clip=args.clip_len,
            step_between_clips=1,
            transform=transform_train,
            frame_rate=15)
        if args.cache_dataset:
            print("Saving dataset_train to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)

    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)

    transform_test = torchvision.transforms.Compose([
        T.ToFloatTensorInZeroOne(),
        T.Resize((128, 171)), normalize,
        T.CenterCrop((112, 112))
    ])

    if args.cache_dataset and os.path.exists(cache_path):
        print("Loading dataset_test from {}".format(cache_path))
        dataset_test, _ = torch.load(cache_path)
        dataset_test.transform = transform_test
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache "
                  "on a single-gpu first, as it will be faster")
        dataset_test = torchvision.datasets.Kinetics400(
            valdir,
            frames_per_clip=args.clip_len,
            step_between_clips=1,
            transform=transform_test,
            frame_rate=15)
        if args.cache_dataset:
            print("Saving dataset_test to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    train_sampler = RandomClipSampler(dataset.video_clips,
                                      args.clips_per_video)
    test_sampler = UniformClipSampler(dataset_test.video_clips,
                                      args.clips_per_video)
    if args.distributed:
        train_sampler = DistributedSampler(train_sampler)
        test_sampler = DistributedSampler(test_sampler)

    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              sampler=train_sampler,
                                              num_workers=args.workers,
                                              pin_memory=True,
                                              collate_fn=collate_fn)

    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=args.batch_size,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   collate_fn=collate_fn)

    print("Creating model")
    model = torchvision.models.video.__dict__[args.model](
        pretrained=args.pretrained)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    lr = args.lr * args.world_size
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.apex:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.apex_opt_level)

    # convert the scheduler to step per iteration rather than per epoch,
    # so that warmup can span multiple epochs
    warmup_iters = args.lr_warmup_epochs * len(data_loader)
    lr_milestones = [len(data_loader) * m for m in args.lr_milestones]
    lr_scheduler = WarmupMultiStepLR(optimizer,
                                     milestones=lr_milestones,
                                     gamma=args.lr_gamma,
                                     warmup_iters=warmup_iters,
                                     warmup_factor=1e-5)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader,
                        device, epoch, args.print_freq, args.apex)
        evaluate(model, criterion, data_loader_test, device=device)
        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir, 'checkpoint.pth'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
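WarmupMultiStepLR is used in every example here but never defined. A minimal sketch of the usual implementation (linear warmup followed by MultiStepLR-style decay, in the style of the maskrcnn-benchmark schedulers); treat it as an approximation of whatever local module these scripts actually import:

from bisect import bisect_right
import torch

class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1e-5,
                 warmup_iters=500, last_epoch=-1):
        self.milestones = sorted(milestones)
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # ramp the factor linearly from warmup_factor up to 1 over warmup_iters steps
        factor = 1.0
        if self.last_epoch < self.warmup_iters:
            alpha = self.last_epoch / self.warmup_iters
            factor = self.warmup_factor * (1 - alpha) + alpha
        # after warmup, decay by gamma at every milestone already passed
        return [base_lr * factor *
                self.gamma ** bisect_right(self.milestones, self.last_epoch)
                for base_lr in self.base_lrs]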
Example #4
def main(args):
    root = args.root
    results_dir = args.results_dir
    save_path = os.path.join(root, results_dir)
    print('root is', root)
    print('save_path is:', save_path)
    os.makedirs(save_path, exist_ok=True)
    json_file_name = os.path.join(save_path, 'args.json')
    with open(json_file_name, 'w') as fp:
        json.dump(dict(args._get_kwargs()), fp, sort_keys=True, indent=4)
    checkpoints_path = os.path.join(save_path, 'checkpoints')
    os.makedirs(checkpoints_path, exist_ok=True)
    sample_output_path = os.path.join(save_path, 'output')
    os.makedirs(sample_output_path, exist_ok=True)

    log_file = os.path.join(save_path, 'log.txt')

    config_logging(log_file)
    device = "cuda"
    logging.info('====>  args{} '.format(args))
    num_workers = args.num_workers

    batch_size = args.batch_size
    dataset_path = args.dataset_path
    print('dataset', dataset_path)

    train_ds = CIFAR100(root=dataset_path,
                        train=True,
                        download=False,
                        transform=transform_train_cifar)
    test_ds = CIFAR100(root=dataset_path,
                       train=False,
                       download=False,
                       transform=transform_test_cifar)

    obtain_indices = get_indices
    end_class = args.end_class  # read but currently unused: all 100 classes are trained below

    classes = list(range(100))

    print('data cls', classes)

    testing_idx = obtain_indices(test_ds, classes, is_training=False)
    training_idx = obtain_indices(train_ds, classes, is_training=True)
    codes_path = os.path.join(root, args.codes_path)

    # print('code_path', codes_path)
    #
    # codes_ds = CodesNpzDataset(codes_path)
    # codes_training_idx = obtain_indices(codes_ds, classes, is_training=True)
    #
    # codes_loader = DataLoader(codes_ds, batch_size=batch_size, num_workers=num_workers, drop_last=False, sampler=SubsetRandomSampler(codes_training_idx))
    train_loader = DataLoader(train_ds,
                              batch_size=batch_size,
                              sampler=SubsetRandomSampler(training_idx),
                              num_workers=num_workers,
                              drop_last=False)
    test_loader = DataLoader(test_ds,
                             batch_size=batch_size,
                             sampler=SubsetRandomSampler(testing_idx),
                             num_workers=num_workers,
                             drop_last=False)

    model = resnet18(3, 100).to(device)

    #
    # mode_pt_path = args.model_pt
    # model_pt = torch.load(mode_pt_path)
    # model.load_state_dict(model_pt)

    opt = optim.SGD(model.parameters(),
                    lr=args.lr,
                    momentum=0.9,
                    weight_decay=5e-4)
    # AE = VQVAE(n_embed=args.n_emb, embed_dim=args.dim_emb).to(device)
    # AE_pt_path = os.path.join(root, args.AE_pt)
    # AE_pt = torch.load(AE_pt_path)
    # AE.load_state_dict(AE_pt)

    # classic CIFAR-100 schedule: decay the lr by 0.2 at epochs 60, 120 and 160
    MILESTONES = [60, 120, 160]
    scheduler = WarmupMultiStepLR(opt,
                                  milestones=MILESTONES,
                                  gamma=0.2,
                                  warmup_iters=args.warm)

    best_acc = 0.0

    for i in range(args.epoch):

        train_pure_resnet(i, train_loader, model, opt, device, classes)
        tmp_acc = test_pure_resnet(i, test_loader, model, device, classes)
        if tmp_acc > best_acc:
            best_acc = tmp_acc
            logging.info('====>  Epoch{}: best_acc {} '.format(i, best_acc))
            pt_path = os.path.join(checkpoints_path, 'classifier_best.pt')
            torch.save(model.state_dict(), pt_path)
        scheduler.step()
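The helper get_indices (aliased as obtain_indices above) isn't shown. A plausible sketch for CIFAR-100, which exposes integer labels via dataset.targets; the is_training flag is accepted for interface parity even though the filtering here would be identical for both splits:

def get_indices(dataset, classes, is_training=True):
    # positions of samples whose label belongs to the requested class subset
    wanted = set(classes)
    return [i for i, target in enumerate(dataset.targets) if target in wanted]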