Example #1
def make_optimizer(model, optimizer_name="AdamW", sam=False):
    optimizer_grouped_parameters = get_optimizer_params(model)
    kwargs = {
        'lr': 5e-5,
        'weight_decay': 0.01,
        # 'betas': (0.9, 0.98),
        # 'eps': 1e-06
    }
    if sam:
        # Wrap the chosen base optimizer in SAM.
        if optimizer_name == "LAMB":
            base_optimizer = Lamb
            optimizer = SAM(optimizer_grouped_parameters,
                            base_optimizer,
                            rho=0.05,
                            **kwargs)
            return optimizer
        elif optimizer_name == "Adam":
            from torch.optim import Adam
            base_optimizer = Adam
            optimizer = SAM(optimizer_grouped_parameters,
                            base_optimizer,
                            rho=0.05,
                            **kwargs)
            return optimizer
        elif optimizer_name == "AdamW":
            from transformers import AdamW
            base_optimizer = AdamW
            optimizer = SAM(optimizer_grouped_parameters,
                            base_optimizer,
                            rho=0.05,
                            **kwargs)
            return optimizer
        else:
            raise Exception('Unknown optimizer: {}'.format(optimizer_name))
    else:
        # Plain (non-SAM) optimizers.
        if optimizer_name == "LAMB":
            optimizer = Lamb(optimizer_grouped_parameters, **kwargs)
            return optimizer
        elif optimizer_name == "Adam":
            from torch.optim import Adam
            optimizer = Adam(optimizer_grouped_parameters, **kwargs)
            return optimizer
        elif optimizer_name == "AdamW":
            optimizer = transformers.AdamW(optimizer_grouped_parameters,
                                           **kwargs)
            return optimizer
        else:
            raise Exception('Unknown optimizer: {}'.format(optimizer_name))
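
With sam=True, make_optimizer returns a SAM-wrapped optimizer, so a single optimizer.step() per batch is not enough: SAM needs two forward-backward passes. A minimal training-step sketch of that pattern is shown below (the same first_step/second_step usage that appears in later examples on this page); model, loss_fn and train_loader are placeholder names, not part of the example above.

optimizer = make_optimizer(model, optimizer_name="AdamW", sam=True)

for inputs, labels in train_loader:
    # first forward-backward pass: gradients at the current weights
    loss = loss_fn(model(inputs), labels)
    loss.backward()
    optimizer.first_step(zero_grad=True)   # move weights to the "sharpness-aware" point

    # second forward-backward pass: gradients at the perturbed weights
    loss_fn(model(inputs), labels).backward()
    optimizer.second_step(zero_grad=True)  # restore weights and apply the base update
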
Example #2
def main():
    args = get_cli_args()
    validate_cli_args(args)
    alphas = np.array(args.alphas)
    beta = np.array(args.beta)**2

    mean_prior = np.array([180., 50., 0.])
    Sigma_prior = 1e-12 * np.eye(3, 3)
    initial_state = Gaussian(mean_prior, Sigma_prior)

    if args.input_data_file:
        data = load_data(args.input_data_file)
    elif args.num_steps:
        # Generate data, assuming `--num-steps` was present in the CL args.
        data = generate_input_data(initial_state.mu.T, args.num_steps,
                                   args.num_landmarks_per_side,
                                   args.max_obs_per_time_step, alphas, beta,
                                   args.dt)
    else:
        raise RuntimeError('Neither an input data file nor a number of steps was specified.')

    should_show_plots = bool(args.animate)
    should_write_movie = bool(args.movie_file)
    should_update_plots = should_show_plots or should_write_movie

    field_map = FieldMap(args.num_landmarks_per_side)

    fig = get_plots_figure(should_show_plots, should_write_movie)
    movie_writer = get_movie_writer(should_write_movie, 'Simulation SLAM',
                                    args.movie_fps, args.plot_pause_len)
    progress_bar = FillingCirclesBar('Simulation Progress', max=data.num_steps)

    # NOTE: this reloads a fixed evaluation input, overriding the data loaded or generated above.
    data = load_data("slam-evaluation-input.npy")

    slam = SAM(beta, alphas, initial_state)

    with movie_writer.saving(
            fig, args.movie_file,
            data.num_steps) if should_write_movie else get_dummy_context_mgr():
        for t in range(data.num_steps):
            # Used as means to include the t-th time-step while plotting.
            tp1 = t + 1

            # Control at the current step.
            u = data.filter.motion_commands[t]
            # Observation at the current step.
            z = data.filter.observations[t]
            # print(data.filter.observations.shape)

            slam.predict(u)
            trajectory, landmarks = slam.update(z)

            progress_bar.next()
            if not should_update_plots:
                continue

            plt.cla()
            plot_field(field_map, z, slam.lm_positions,
                       slam.lm_correspondences)
            plot_robot(data.debug.real_robot_path[t])
            plot_observations(data.debug.real_robot_path[t],
                              data.debug.noise_free_observations[t],
                              data.filter.observations[t])

            plt.plot(data.debug.real_robot_path[1:tp1, 0],
                     data.debug.real_robot_path[1:tp1, 1], 'm')
            plt.plot(data.debug.noise_free_robot_path[1:tp1, 0],
                     data.debug.noise_free_robot_path[1:tp1, 1], 'g')

            plt.plot([data.debug.real_robot_path[t, 0]],
                     [data.debug.real_robot_path[t, 1]], '*r')
            plt.plot([data.debug.noise_free_robot_path[t, 0]],
                     [data.debug.noise_free_robot_path[t, 1]], '*g')

            # TODO: plot SLAM solution
            plt.plot(np.array(trajectory)[:, 0], np.array(trajectory)[:, 1])
            plt.scatter(np.array(landmarks)[:, 0], np.array(landmarks)[:, 1])

            # print(t)

            # for lm in slam.lm_positions:
            #     # print(len(lm))
            #     if len(lm)>5:
            #         lm_mu, lm_sigma = get_gaussian_statistics_xy(np.array(lm[-5:]))
            #         # print('lm_mu',lm_mu)
            #         # print('lm_sigma',lm_sigma)
            #         # print('plot lm')
            #         plot2dcov(lm_mu, lm_sigma, 3, 50)

            if should_show_plots:
                # Draw all the plots and pause to create an animation effect.
                plt.draw()
                plt.pause(args.plot_pause_len)

            if should_write_movie:
                movie_writer.grab_frame()

    progress_bar.finish()

    plt.show(block=True)
Example #3
def train_loop(folds, fold):

    seed_torch(seed=CFG.seed)

    LOGGER.info(f'========== fold: {fold} training ============')

    # ======================================================
    # loader
    # ======================================================
    trn_idx = folds[folds['fold'] != fold].index
    val_idx = folds[folds['fold'] == fold].index

    train_folds = folds.loc[trn_idx].reset_index(drop=True)
    valid_folds = folds.loc[val_idx].reset_index(drop=True)

    train_dataset = TrainDataset(train_folds,
                                 transform=get_transforms(data='train'))
    valid_dataset = TrainDataset(valid_folds,
                                 transform=get_transforms(data='valid'))

    train_loader = DataLoader(train_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=True,
                              num_workers=CFG.num_workers,
                              pin_memory=True,
                              drop_last=False)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=False,
                              num_workers=CFG.num_workers,
                              pin_memory=True,
                              drop_last=False)

    # ===============================================
    # scheduler
    # ===============================================
    def get_scheduler(optimizer):
        if CFG.scheduler == 'ReduceLROnPlateau':
            scheduler = ReduceLROnPlateau(optimizer,
                                          mode='min',
                                          factor=CFG.factor,
                                          patience=CFG.patience,
                                          verbose=True,
                                          eps=CFG.eps)
        elif CFG.scheduler == 'CosineAnnealingLR':
            scheduler = CosineAnnealingLR(optimizer,
                                          T_max=CFG.T_max,
                                          eta_min=CFG.min_lr,
                                          last_epoch=-1)
        elif CFG.scheduler == 'CosineAnnealingWarmRestarts':
            scheduler = CosineAnnealingWarmRestarts(optimizer,
                                                    T_0=CFG.T_0,
                                                    T_mult=1,
                                                    eta_min=CFG.min_lr,
                                                    last_epoch=-1)
        return scheduler

    # ===============================================
    # model & optimizer
    # ===============================================
    model = CustomEfficientNetB3ns(CFG.model_name, pretrained=True)

    # Freeze everything except the classifier layers for the first epochs.
    for name, param in model.model.named_parameters():
        if 'classifier' not in name:
            param.requires_grad = False

    model.to(device)

    base_optimizer = Adam
    optimizer = SAM(model.parameters(),
                    base_optimizer,
                    lr=CFG.lr_1,
                    weight_decay=CFG.weight_decay,
                    amsgrad=False)

    scheduler = get_scheduler(optimizer)

    # ===============================================
    # apex
    # ===============================================
    if CFG.apex:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level='O1',
                                          verbosity=0)

    # ===============================================
    # loop
    # ===============================================
    def get_loss_train():
        if CFG.loss_train == 'CrossEntropyLoss':
            loss_train = nn.CrossEntropyLoss()
        elif CFG.loss_train == 'LabelSmoothing':
            loss_train = LabelSmoothingLoss(classes=CFG.target_size,
                                            smoothing=CFG.smooth)
        elif CFG.loss_train == 'FocalLoss':
            loss_train = FocalLoss().to(device)
        elif CFG.loss_train == 'FocalCosineLoss':
            loss_train = FocalCosineLoss()
        elif CFG.loss_train == 'SymmetricCrossEntropyLoss':
            loss_train = SymmetricCrossEntropy().to(device)
        elif CFG.loss_train == 'BiTemperedLoss':
            loss_train = BiTemperedLogisticLoss(t1=CFG.t1,
                                                t2=CFG.t2,
                                                smoothing=CFG.smooth)
        elif CFG.loss_train == 'TaylorCrossEntropyLoss':
            loss_train = TaylorCrossEntropyLoss(n=6, smoothing=CFG.smoothing)
        return loss_train

    loss_train = get_loss_train()
    LOGGER.info(f'loss_train: {loss_train}')
    loss_metric = nn.CrossEntropyLoss()

    best_score = 0.
    best_loss = np.inf

    for epoch in range(CFG.epochs):

        start_time = time.time()

        if epoch == 1:

            # Unfreeze all weights at the 2nd epoch.
            for param in model.model.parameters():
                param.requires_grad = True

            # Drop the learning rate from 4e-3 to 4e-4.
            base_optimizer = Adam
            optimizer = SAM(model.parameters(),
                            base_optimizer,
                            lr=CFG.lr_2,
                            weight_decay=CFG.weight_decay,
                            amsgrad=False)
            scheduler = get_scheduler(optimizer)

            LOGGER.info('requires_grad of all parameters are unlocked')

        # train
        avg_loss = train_fn(train_loader, model, loss_train, loss_metric,
                            optimizer, epoch, scheduler, device)

        # eval
        avg_val_loss, preds = valid_fn(valid_loader, model, loss_metric,
                                       device)
        valid_labels = valid_folds[CFG.target_col].values

        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(avg_val_loss)
        elif isinstance(scheduler, CosineAnnealingLR):
            scheduler.step()
        elif isinstance(scheduler, CosineAnnealingWarmRestarts):
            scheduler.step()

        # scoring
        score = get_score(valid_labels, preds.argmax(1))

        elapsed = time.time() - start_time

        LOGGER.info(
            f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} avg_val_loss: {avg_val_loss:.4f} time: {elapsed:.0f}s'
        )
        LOGGER.info(f'Epoch {epoch+1} - Accuracy: {score}')

        if score > best_score:
            best_score = score
            LOGGER.info(
                f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
            torch.save({
                'model': model.state_dict(),
                'preds': preds
            }, OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best.pth')

        # Save every epoch's checkpoint for later inference.
        torch.save({'model': model.state_dict()}, OUTPUT_DIR +
                   f'{CFG.model_name}_fold{fold}_epoch{epoch+1}.pth')

    check_point = torch.load(OUTPUT_DIR +
                             f'{CFG.model_name}_fold{fold}_best.pth')
    valid_folds[[str(c) for c in range(5)]] = check_point['preds']
    valid_folds['preds'] = check_point['preds'].argmax(1)

    return valid_folds
Example #4
def main_train(config, checkpoint_dir=None):
    global args, best_corr
    best_corr = 0.0

    args.store_name = '{}'.format(args.model)
    args.store_name = args.store_name + datetime.now().strftime('_%m-%d_%H-%M-%S')
    args.start_epoch = 0

    # check_rootfolders(args)
    if args.model == 'Baseline':
        model = Baseline()
    elif args.model == 'TCFPN':
        model = TCFPN(layers=[48, 64, 96], in_channels=(2048 + 128), num_classes=15, kernel_size=11)
    
    model = torch.nn.DataParallel(model).cuda()

    if config['optimizer'] == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    elif config['optimizer'] == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'])
    
    # custom optimizer
    if args.use_sam:
        base_optim = torch.optim.Adam
        optimizer = SAM(model.parameters(), base_optim, lr=config['lr'])
    # custom lr scheduler
    if args.use_cos_wr:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=args.cos_wr_t0,T_mult=args.cos_wr_t_mult)
    elif args.use_cos:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.cos_t_max)
    # SWA
    if args.use_swa:
        swa_model = torch.optim.swa_utils.AveragedModel(model)
        swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=config['lr'])

    # ckpt structure {epoch, state_dict, optimizer, best_corr}
    # if args.resume and os.path.isfile(args.resume):
    #     print('Load checkpoint:', args.resume)
    #     ckpt = torch.load(args.resume)
    #     args.start_epoch = ckpt['epoch']
    #     best_corr = ckpt['best_corr']
    #     model.load_state_dict(ckpt['state_dict'])
    #     optimizer.load_state_dict(ckpt['optimizer'])
    #     print('Loaded ckpt at epoch:', args.start_epoch)
    if checkpoint_dir:
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)


    # initialize datasets
    train_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=args.train_csv,
            vidmap_path=args.train_vidmap,
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='train', lpfilter=args.lp_filter
        ),
        batch_size=config['batch_size'], shuffle=True,
        num_workers=args.workers, pin_memory=False,
        drop_last=True
    )

    val_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=args.val_csv,
            vidmap_path=args.val_vidmap,
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='val'
        ),
        batch_size=None, shuffle=False,
        num_workers=args.workers, pin_memory=False
    )

    accuracy = correlation
    # with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f:
    #     f.write(str(args))
    
    # tb_writer = SummaryWriter(log_dir=os.path.join(args.root_log, args.store_name))

    for epoch in range(args.start_epoch, args.epochs):
        # train
        train(train_loader, model, optimizer, epoch, None, None)
        # do lr scheduling after epoch
        if args.use_swa and epoch >= args.swa_start:
            print('swa stepping...')
            swa_model.update_parameters(model)
            swa_scheduler.step()
        elif args.use_cos_wr:
            print('cos warm restart (T0:{} Tm:{}) stepping...'.format(args.cos_wr_t0, args.cos_wr_t_mult))
            scheduler.step()
        elif args.use_cos:
            print('cos (Tmax:{}) stepping...'.format(args.cos_t_max))
            scheduler.step()
        
        # validate
        if args.use_swa and epoch >= args.swa_start:
            # validate use swa model
            corr, loss = validate(val_loader, swa_model, accuracy, epoch, None, None)
        else:
            corr, loss = validate(val_loader, model, accuracy, epoch, None, None)
        is_best = corr > best_corr
        best_corr = max(corr, best_corr)
        # tb_writer.add_scalar('acc/validate_corr_best', best_corr, epoch)
        # output_best = 'Best corr: %.4f\n' % (best_corr)
        # print(output_best)
        # save_checkpoint({
        #     'epoch': epoch + 1,
        #     'state_dict': model.state_dict(),
        #     'optimizer': optimizer.state_dict(),
        #     'best_corr': best_corr,
        # }, is_best)
        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            if is_best:
                path = os.path.join(checkpoint_dir, "checkpoint_best")
            torch.save((model.state_dict(), optimizer.state_dict()), path)
        tune.report(loss=loss, accuracy=corr, best_corr=best_corr)
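
This variant of main_train reports metrics through tune.report and loads/saves checkpoints through tune.checkpoint_dir, so it is meant to be driven by Ray Tune. A rough launch sketch is given below; the search-space values are illustrative only and are not taken from the original code.

from ray import tune

config = {
    "optimizer": tune.choice(["adam", "adamw"]),
    "lr": tune.loguniform(1e-5, 1e-2),
    "batch_size": tune.choice([16, 32, 64]),
}
analysis = tune.run(
    main_train,                 # the trainable defined above (relies on the global args)
    config=config,
    num_samples=10,
    metric="best_corr",
    mode="max",
)
print("Best config found:", analysis.best_config)
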
Example #5
def main_train():
    global args, best_corr

    args.store_name = '{}'.format(args.model)
    args.store_name = args.store_name + datetime.now().strftime(
        '_%m-%d_%H-%M-%S')
    args.start_epoch = 0

    if not args.val_only:
        check_rootfolders(args)
    if args.model == 'Baseline':
        if args.cls_indices:
            model = Baseline(args.img_feat_size,
                             args.au_feat_size,
                             num_classes=len(args.cls_indices))
        else:
            print('Feature size:', args.img_feat_size, args.au_feat_size)
            model = Baseline(args.img_feat_size, args.au_feat_size)
    elif args.model == 'TCFPN':
        model = TCFPN(layers=[48, 64, 96],
                      in_channels=(128),
                      num_classes=15,
                      kernel_size=11)
    elif args.model == 'BaseAu':
        model = Baseline_Au(args.au_feat_size)
    elif args.model == 'BaseImg':
        model = Baseline_Img(args.img_feat_size)
    elif args.model == 'EmoBase':
        model = EmoBase()

    model = torch.nn.DataParallel(model).cuda()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)
    # optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    # custom optimizer
    if args.use_sam:
        base_optim = torch.optim.Adam
        optimizer = SAM(model.parameters(), base_optim, lr=args.learning_rate)
    # custom lr scheduler
    if args.use_cos_wr:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=args.cos_wr_t0, T_mult=args.cos_wr_t_mult)
    elif args.use_cos:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.cos_t_max)
    elif args.use_multistep:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, args.step_milestones, args.step_decay)
    # SWA
    if args.use_swa:
        swa_model = torch.optim.swa_utils.AveragedModel(model)
        swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
                                                    swa_lr=args.learning_rate)

    # ckpt structure {epoch, state_dict, optimizer, best_corr}
    if args.resume and os.path.isfile(args.resume):
        print('Load checkpoint:', args.resume)
        ckpt = torch.load(args.resume)
        args.start_epoch = ckpt['epoch']
        best_corr = ckpt['best_corr']
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])
        print('Loaded ckpt at epoch:', args.start_epoch)

    # initialize datasets
    train_loader = torch.utils.data.DataLoader(dataset=EEV_Dataset(
        csv_path=args.train_csv,
        vidmap_path=args.train_vidmap,
        image_feat_path=args.image_features,
        audio_feat_path=args.audio_features,
        mode='train',
        lpfilter=args.lp_filter,
        train_freq=args.train_freq,
        val_freq=args.val_freq,
        cls_indices=args.cls_indices),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=False,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(dataset=EEV_Dataset(
        csv_path=args.val_csv,
        vidmap_path=args.val_vidmap,
        image_feat_path=args.image_features,
        audio_feat_path=args.audio_features,
        mode='val',
        train_freq=args.train_freq,
        val_freq=args.val_freq,
        cls_indices=args.cls_indices,
        repeat_sample=args.repeat_sample),
                                             batch_size=None,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=False)

    accuracy = correlation

    if args.val_only:
        print('Run validation ...')
        print('start epoch:', args.start_epoch, 'model:', args.resume)
        validate(val_loader, model, accuracy, args.start_epoch, None, None)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))

    tb_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        train(train_loader, model, optimizer, epoch, log_training, tb_writer)
        # do lr scheduling after epoch
        if args.use_swa and epoch >= args.swa_start:
            print('swa stepping...')
            swa_model.update_parameters(model)
            swa_scheduler.step()
        elif args.use_cos_wr or args.use_cos or args.use_multistep:
            scheduler.step()

        if (epoch + 1) > 2 and ((epoch + 1) % args.eval_freq == 0 or
                                (epoch + 1) == args.epochs):
            # validate
            if args.use_swa and epoch >= args.swa_start:
                # validate use swa model
                corr = validate(val_loader, swa_model, accuracy, epoch,
                                log_training, tb_writer)
            else:
                corr = validate(val_loader, model, accuracy, epoch,
                                log_training, tb_writer)
            is_best = corr > best_corr
            best_corr = max(corr, best_corr)
            tb_writer.add_scalar('acc/validate_corr_best', best_corr, epoch)
            output_best = 'Best corr: %.4f\n' % (best_corr)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_corr': best_corr,
                }, is_best)
Example #6
    initialize(args, seed=42)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    dataset = Cifar(args.batch_size, args.threads)
    log = Log(log_each=10)
    model = WideResNet(args.depth,
                       args.width_factor,
                       args.dropout,
                       in_channels=3,
                       labels=10).to(device)

    base_optimizer = torch.optim.SGD
    optimizer = SAM(model.parameters(),
                    base_optimizer,
                    rho=args.rho,
                    lr=args.learning_rate,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, args.learning_rate, args.epochs)

    for epoch in range(args.epochs):
        model.train()
        log.train(len_dataset=len(dataset.train))

        for batch in dataset.train:
            inputs, targets = (b.to(device) for b in batch)

            # first forward-backward step
            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets)
            loss.mean().backward()
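
            # NOTE: the original example appears to be cut off after the first
            # backward pass. A plausible continuation, following the SAM
            # first_step/second_step pattern used in the other examples on this
            # page (a sketch, not the original code):
            optimizer.first_step(zero_grad=True)

            # second forward-backward step, evaluated at the perturbed weights
            smooth_crossentropy(model(inputs), targets).mean().backward()
            optimizer.second_step(zero_grad=True)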
Example #7
def get_optimizer(model):
    return SAM(filter(lambda p: p.requires_grad, model.parameters()),
               torch.optim.SGD,
               lr=config['base_lr'],
               momentum=config['momentum'],
               weight_decay=config['weight_decay'])
Example #8
def run(args):
    import torch

    from denoiser import distrib
    from denoiser.data import NoisyCleanSet
    from denoiser.demucs import Demucs
    from denoiser.solver import Solver
    distrib.init(args)

    model = Demucs(**args.demucs)

    if args.show:
        logger.info(model)
        mb = sum(p.numel() for p in model.parameters()) * 4 / 2**20
        logger.info('Size: %.1f MB', mb)
        if hasattr(model, 'valid_length'):
            field = model.valid_length(1)
            logger.info('Field: %.1f ms', field / args.sample_rate * 1000)
        return

    assert args.batch_size % distrib.world_size == 0
    args.batch_size //= distrib.world_size

    length = int(args.segment * args.sample_rate)
    stride = int(args.stride * args.sample_rate)
    # Demucs requires a specific number of samples to avoid 0 padding during training
    if hasattr(model, 'valid_length'):
        length = model.valid_length(length)
    kwargs = {"matching": args.dset.matching, "sample_rate": args.sample_rate}
    # Building datasets and loaders
    tr_dataset = NoisyCleanSet(args.dset.train,
                               length=length,
                               stride=stride,
                               pad=args.pad,
                               **kwargs)
    tr_loader = distrib.loader(tr_dataset,
                               batch_size=args.batch_size,
                               shuffle=True,
                               num_workers=args.num_workers)
    if args.dset.valid:
        cv_dataset = NoisyCleanSet(args.dset.valid, **kwargs)
        cv_loader = distrib.loader(cv_dataset,
                                   batch_size=1,
                                   num_workers=args.num_workers)
    else:
        cv_loader = None
    if args.dset.test:
        tt_dataset = NoisyCleanSet(args.dset.test, **kwargs)
        tt_loader = distrib.loader(tt_dataset,
                                   batch_size=1,
                                   num_workers=args.num_workers)
    else:
        tt_loader = None
    data = {
        "tr_loader": tr_loader,
        "cv_loader": cv_loader,
        "tt_loader": tt_loader
    }

    # torch also initialize cuda seed if available
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        model.cuda()

    # optimizer
    if args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     betas=(0.9, args.beta2))
    elif args.optim == "sam":
        # Adding SAM optimizer: https://github.com/davda54/sam
        # SAM wraps a base optimizer class for the "sharpness-aware" update.
        base_optimizer = torch.optim.Adam
        optimizer = SAM(model.parameters(),
                        base_optimizer,
                        lr=args.lr,
                        betas=(0.9, args.beta2))
    else:
        logger.fatal('Invalid optimizer %s', args.optim)
        os._exit(1)

    # Construct Solver
    solver = Solver(data, model, optimizer, args)
    solver.train()
Example #9
    def update_weights(self, model, global_round, idx_user):
        # Set mode to train model
        #        model.to(self.device)
        #        model.train()
        epoch_loss = []
        total_norm = []
        loss_list = []
        conv_grad = []
        fc_grad = []
        # Set optimizer for the local updates
        if self.args.optimizer == 'sgd_bench':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=self.args.lr,
                                        momentum=0.9)
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=self.args.lr,
                                         weight_decay=1e-4)
        elif self.args.optimizer == 'sgd_vc':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=self.args.lr,
                                        weight_decay=1e-4,
                                        momentum=0.9)
        elif self.args.optimizer == 'sam':
            base_optimizer = torch.optim.SGD  # define an optimizer for the "sharpness-aware" update
            optimizer = SAM(model.parameters(),
                            base_optimizer,
                            lr=self.args.lr,
                            momentum=0.9,
                            weight_decay=1e-4)
        elif self.args.optimizer == 'no_weight_decay':
            optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr)
        elif self.args.optimizer == 'clip':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=self.args.lr,
                                        weight_decay=1e-4)
        elif self.args.optimizer == 'resnet':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=self.args.lr,
                                        momentum=0.9,
                                        weight_decay=5e-4)
        elif self.args.optimizer == 'no_momentum':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=self.args.lr,
                                        weight_decay=1e-4)
        elif self.args.optimizer == 'clip_nf':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=self.args.lr,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            if 'resnet' in self.args.model:
                optimizer = AGC(model.parameters(),
                                optimizer,
                                model=model,
                                ignore_agc=['fc'],
                                clipping=1e-3)
            else:
                optimizer = AGC(model.parameters(),
                                optimizer,
                                model=model,
                                ignore_agc=['fc1', 'fc2', 'fc3'],
                                clipping=1e-3)
            # optimizer = SGD_AGC(model.parameters(), lr=self.args.lr, momentum=0.9, weight_decay=5e-4, clipping=1e-3)

        for iter in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)

                optimizer.zero_grad()
                log_probs = model(images)
                loss = self.criterion(log_probs, labels)
                if self.args.verbose == 0:
                    del images
                    del labels
                    torch.cuda.empty_cache()

                loss.backward()

                # record gradients for inspection - how does BN
                conv_grad.append(model.conv1.weight.grad.clone().to('cpu'))
                if self.args.optimizer != 'clip':
                    total_norm.append(check_norm(model))

                if self.args.model == 'cnn' or self.args.model == 'cnn_ws':
                    fc_grad.append(model.fc3.weight.grad.clone().to('cpu'))
                else:
                    fc_grad.append(model.fc.weight.grad.clone().to('cpu'))

                if self.args.optimizer == 'sam':
                    optimizer.first_step(zero_grad=True)
                    log_probs = model(images)
                    loss = self.criterion(log_probs, labels)
                    loss.backward()
                    optimizer.second_step(zero_grad=True)
                elif self.args.optimizer == 'clip':
                    max_norm = 0.3
                    if self.args.lr == 5:
                        max_norm = 0.08
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   max_norm)
                    total_norm.append(check_norm(model))
                    optimizer.step()
                else:  # non-SAM case
                    optimizer.step()
                # print(optimizer.param_groups[0]['lr'])  # for checking lr decay
                if self.args.verbose:
                    print(
                        '|Client : {} Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
                        .format(idx_user, global_round + 1, iter + 1,
                                batch_idx * len(images),
                                len(self.trainloader.dataset),
                                100. * batch_idx / len(self.trainloader),
                                loss.item()))
                # self.logger.add_scalar('loss', loss.item())
                batch_loss.append(loss.item())
                # per-iteration loss, recorded for inspection - how does BN
                loss_list.append(loss.item())
            print(total_norm)  # gradient norms, for inspection
            epoch_loss.append(sum(batch_loss) / len(batch_loss))

        return model.state_dict(), sum(epoch_loss) / len(
            epoch_loss), loss_list, conv_grad, fc_grad, total_norm
Example #10
def prepare(args):
    global trainloader
    global testloader
    global net
    global criterion
    global optimizer
    global scheduler

    # Data
    print('==> Preparing data..')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        CIFAR10Policy(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
        Cutout(n_holes=args.n_holes, length=args.length)
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data',
                                            train=True,
                                            download=True,
                                            transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=512,
                                              shuffle=True,
                                              num_workers=4)

    testset = torchvision.datasets.CIFAR10(root='./data',
                                           train=False,
                                           download=True,
                                           transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=512,
                                             shuffle=False,
                                             num_workers=4)

    #classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # Model
    print('==> Building model..')
    net = CXH()  #CXH_Squeeze_Excitation() #CXH()

    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    criterion = nn.CrossEntropyLoss()
    # criterion = CrossEntropyLabelSmooth(10)
    # optimizer = optim.SGD(net.parameters(), lr=0.1,
    #                       momentum=0.9, weight_decay=5e-4)

    optimizer = SAM(net.parameters(), torch.optim.SGD, lr=0.1, momentum=0.9)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, 1, 2)
Example #11
def main():
    args = get_cli_args()
    validate_cli_args(args)
    alphas = np.array(args.alphas)
    beta = np.array(args.beta)
    mean_prior = np.array([180., 50., 0.])
    Sigma_prior = 1e-12 * np.eye(3, 3)
    initial_state = Gaussian(mean_prior, Sigma_prior)

    if args.input_data_file:
        data = load_data(args.input_data_file)
    elif args.num_steps:
        # Generate data, assuming `--num-steps` was present in the CL args.
        data = generate_input_data(initial_state.mu.T, args.num_steps,
                                   args.num_landmarks_per_side,
                                   args.max_obs_per_time_step, alphas, beta,
                                   args.dt)
    else:
        raise RuntimeError('Neither an input data file nor a number of steps was specified.')

    should_show_plots = bool(args.animate)
    should_write_movie = bool(args.movie_file)
    should_update_plots = should_show_plots or should_write_movie

    field_map = FieldMap(args.num_landmarks_per_side)

    fig_robot = get_plots_figure(should_show_plots, should_write_movie)
    movie_writer = get_movie_writer(should_write_movie, 'Simulation SLAM',
                                    args.movie_fps, args.plot_pause_len)
    progress_bar = FillingCirclesBar('Simulation Progress', max=data.num_steps)

    # sam object init:
    sam = SAM(initial_state, args)
    mu_traj = np.array([None, None])
    theta = []
    with movie_writer.saving(
            fig_robot, args.movie_file,
            data.num_steps) if should_write_movie else get_dummy_context_mgr():
        for t in range(data.num_steps):
            # for t in range(50):
            # Used as means to include the t-th time-step while plotting.
            tp1 = t + 1

            # Control at the current step.
            u = data.filter.motion_commands[t]
            # Observation at the current step.
            z = data.filter.observations[t]

            # TODO SLAM predict(u)
            mu, Sigma = sam.predict(u)

            # TODO SLAM update
            mu, Sigma = sam.update(u, z)
            mu_traj = np.vstack((mu_traj, mu[:2]))
            theta.append(mu[2])

            progress_bar.next()
            if not should_update_plots:
                continue

            plt.figure(1)
            plt.cla()
            plot_field(field_map, z)
            plot_robot(data.debug.real_robot_path[t])
            plot_observations(data.debug.real_robot_path[t],
                              data.debug.noise_free_observations[t],
                              data.filter.observations[t])

            plt.plot(data.debug.real_robot_path[1:tp1, 0],
                     data.debug.real_robot_path[1:tp1, 1], 'm')
            plt.plot(data.debug.noise_free_robot_path[1:tp1, 0],
                     data.debug.noise_free_robot_path[1:tp1, 1], 'g')

            plt.plot([data.debug.real_robot_path[t, 0]],
                     [data.debug.real_robot_path[t, 1]], '*r')
            plt.plot([data.debug.noise_free_robot_path[t, 0]],
                     [data.debug.noise_free_robot_path[t, 1]], '*g')

            # TODO plot SLAM solution
            # robot filtered trajectory and covariance
            plt.plot(mu_traj[:, 0], mu_traj[:, 1], 'blue')
            plot2dcov(mu[:2], Sigma[:2, :2], color='b', nSigma=3, legend=None)

            plt.figure(2, figsize=(8, 6))
            plt.cla()
            plt.spy(sam.A, marker='o', markersize=5)

            if should_show_plots:
                # Draw all the plots and pause to create an animation effect.
                plt.draw()
                plt.pause(args.plot_pause_len)

            if should_write_movie:
                movie_writer.grab_frame()

    progress_bar.finish()

    plt.show()
Example #12
def train(model,
          n_epochs,
          learningrate,
          train_loader,
          test_loader,
          use_sam=False):
    # optimizer
    if use_sam:
        optimizer = SAM(filter(lambda p: p.requires_grad, model.parameters()),
                        optim.Adam,
                        lr=learningrate)
    else:
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=learningrate)
    # scheduler
    #scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
    best_acc = 0
    best_model = None
    for epoch in range(n_epochs):
        epoch_loss = 0
        epoch_accuracy = 0
        model.train()
        for data, label in tqdm(train_loader):
            data = data.to(device)
            label = label.to(device)

            output = model(data)
            loss = criterion(output, label)

            if use_sam:
                #optimizer.zero_grad()
                loss.backward()
                optimizer.first_step(zero_grad=True)

                # second forward-backward pass
                output = model(data)
                loss = criterion(output, label)
                loss.backward()
                optimizer.second_step(zero_grad=True)
            else:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            acc = (output.argmax(dim=1) == label).float().mean()
            epoch_accuracy += acc / len(train_loader)
            epoch_loss += loss / len(train_loader)

        model.eval()
        with torch.no_grad():
            epoch_val_accuracy = 0
            epoch_val_loss = 0
            epoch_Positive = 0
            epoch_Negative = 0
            epoch_TP = 0
            epoch_FP = 0
            epoch_TN = 0
            epoch_FN = 0
            for data, label in tqdm(test_loader):
                data = data.to(device)
                label = label.to(device)

                val_output = model(data)
                val_loss = criterion(val_output, label)

                acc = (val_output.argmax(dim=1) == label).float().mean()
                epoch_val_accuracy += acc / len(test_loader)
                epoch_val_loss += val_loss / len(test_loader)
                c_True_Positive, c_False_Positive, c_True_Negative, c_False_Negative, c_Positive, c_Negative = evaluate(
                    val_output, label)
                epoch_TP += c_True_Positive
                epoch_FP += c_False_Positive
                epoch_TN += c_True_Negative
                epoch_FN += c_False_Negative
                epoch_Positive += c_Positive
                epoch_Negative += c_Negative
            Recall = (epoch_TP) / (epoch_TP + epoch_FN)
            Precision = (epoch_TP) / (epoch_TP + epoch_FP)
            F1 = (2 * (Recall * Precision)) / (Recall + Precision)

        print(
            f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
        )
        print(
            f"Recall: {Recall:.4f},  Precision: {Precision:.4f}, F1 Score: {F1:.4f}"
        )
        if best_acc < epoch_val_accuracy:
            best_acc = epoch_val_accuracy
            best_model = copy.deepcopy(model.state_dict())
        #scheduler.step()

    if best_model is not None:
        model.load_state_dict(best_model)
        print(f"Best acc:{best_acc}")
        model.eval()
        with torch.no_grad():
            epoch_val_accuracy = 0
            epoch_val_loss = 0
            for data, label in test_loader:
                data = data.to(device)
                label = label.to(device)

                val_output = model(data)
                val_loss = criterion(val_output, label)

                acc = (val_output.argmax(dim=1) == label).float().mean()
                epoch_val_accuracy += acc / len(test_loader)
                epoch_val_loss += val_loss / len(test_loader)

        print(
            f"val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
        )
    else:
        print(f"No best model Best acc:{best_acc}")
Example #13
def main():
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize / load checkpoint
    if use_sam:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout,
                                       use_glove=use_glove,
                                       word_map=word_map)
        base_optimizer = torch.optim.SGD
        decoder_optimizer = SAM(filter(lambda p: p.requires_grad,
                                       decoder.parameters()),
                                base_optimizer,
                                lr=decoder_lr,
                                momentum=0.9)

        checkpoint = torch.load(checkpoint)
        encoder = checkpoint['encoder']
        encoder_optimizer = None
        print("Loading best encoder but random decoder and using SAM...")

    elif checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout,
                                       use_glove=use_glove,
                                       word_map=word_map)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        print(f"Continuing training from epoch {start_epoch}...")
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        if use_sam:
            lr = checkpoint['decoder_optimizer'].param_groups[0]['lr']
            base_optimizer = torch.optim.SGD
            decoder_optimizer = SAM(filter(lambda p: p.requires_grad,
                                           decoder.parameters()),
                                    base_optimizer,
                                    lr=lr,
                                    momentum=0.9)
        else:
            decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']

        if use_sam and fine_tune_encoder is True:
            lr = checkpoint['encoder_optimizer'].param_groups[0]['lr']
            base_optimizer = torch.optim.SGD
            encoder_optimizer = SAM(filter(lambda p: p.requires_grad,
                                           encoder.parameters()),
                                    base_optimizer,
                                    lr=lr,
                                    momentum=0.9)
        else:
            encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            if use_sam:
                base_optimizer = torch.optim.SGD
                encoder_optimizer = SAM(filter(lambda p: p.requires_grad,
                                               encoder.parameters()),
                                        base_optimizer,
                                        lr=encoder_lr,
                                        momentum=0.9)
            else:
                encoder_optimizer = torch.optim.Adam(params=filter(
                    lambda p: p.requires_grad, encoder.parameters()),
                                                     lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # initialize dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(CocoCaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transforms=transforms.Compose([normalize])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(CocoCaptionDataset(
        data_folder,
        data_name,
        'VAL',
        transforms=transforms.Compose([normalize])),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)

    print(f"Train dataloader len: {len(train_loader)}")
    print(f"Val dataloader len: {len(val_loader)}")

    # set up tensorboard
    train_writer = SummaryWriter(
        os.path.join(log_directory, f"{log_name}/train"))
    val_writer = SummaryWriter(os.path.join(log_directory, f"{log_name}/val"))

    # Epochs
    for epoch in tqdm(range(start_epoch, epochs)):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch,
              train_writer=train_writer)

        # One epoch's validation
        recent_bleu4, val_loss, val_top5_acc = validate(val_loader=val_loader,
                                                        encoder=encoder,
                                                        decoder=decoder,
                                                        criterion=criterion)
        val_writer.add_scalar('Epoch loss', val_loss, epoch + 1)
        val_writer.add_scalar('Epoch top-5 accuracy', val_top5_acc, epoch + 1)
        val_writer.add_scalar('BLEU-4', recent_bleu4, epoch + 1)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        checkpoint_name = data_name
        if use_glove:
            checkpoint_name = f"glove_{checkpoint_name}"
        if use_sam:
            checkpoint_name = f"sam_{checkpoint_name}"
        save_checkpoint(checkpoint_name, epoch, epochs_since_improvement,
                        encoder, decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best, checkpoint_path)
Example #14
def main_train():
    global args, best_corr

    args.store_name = '{}'.format(args.model)
    args.store_name = 'zzd' + args.store_name + datetime.now().strftime(
        '_%m-%d_%H-%M-%S')
    args.start_epoch = 0

    check_rootfolders(args)
    if args.model == 'Baseline':
        model = Baseline()
        model2 = Baseline()
    elif args.model == 'TCFPN':
        model = TCFPN(layers=[48, 64, 96],
                      in_channels=(2048 + 128),
                      num_classes=15,
                      kernel_size=11)

    model = torch.nn.DataParallel(model).cuda()
    model2 = torch.nn.DataParallel(model2).cuda()

    # optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    args.learning_rate = 0.02
    print('init: %f' % args.learning_rate)
    optimizer = torch.optim.SGD([
        {
            'params': model.parameters(),
            'lr': args.learning_rate
        },
        {
            'params': model2.parameters(),
            'lr': args.learning_rate
        },
    ],
                                weight_decay=1e-4,
                                momentum=0.9,
                                nesterov=True)
    # custom optimizer
    if args.use_sam:
        base_optim = torch.optim.Adam
        optimizer = SAM(model.parameters(), base_optim, lr=args.learning_rate)
    # custom lr scheduler
    #print(args.use_cos_wr)
    #if args.use_cos_wr:
    #args.cos_wr_t0 = 10
    #print('using Restart: %d'%args.cos_wr_t0)
    #scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=args.cos_wr_t0,T_mult=2)
    #elif args.use_cos:
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.cos_t_max)
    # SWA
    if args.use_swa:
        swa_model = torch.optim.swa_utils.AveragedModel(model)
        swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
                                                    swa_lr=args.learning_rate)

    # ckpt structure {epoch, state_dict, optimizer, best_corr}
    if args.resume and os.path.isfile(args.resume):
        print('Load checkpoint:', args.resume)
        ckpt = torch.load(args.resume)
        args.start_epoch = ckpt['epoch']
        best_corr = ckpt['best_corr']
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])
        print('Loaded ckpt at epoch:', args.start_epoch)

    # initialize datasets
    train_loader = torch.utils.data.DataLoader(dataset=EEV_Dataset(
        csv_path=args.train_csv,
        vidmap_path=args.train_vidmap,
        image_feat_path=args.image_features,
        audio_feat_path=args.audio_features,
        mode='train',
        lpfilter=args.lp_filter,
        train_freq=args.train_freq,
        val_freq=args.val_freq),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=2,
                                               pin_memory=True,
                                               drop_last=True)
    train_loader2 = torch.utils.data.DataLoader(dataset=EEV_Dataset(
        csv_path=args.train_csv,
        vidmap_path=args.train_vidmap,
        image_feat_path=args.image_features,
        audio_feat_path=args.audio_features,
        mode='train',
        lpfilter=args.lp_filter,
        train_freq=args.train_freq,
        val_freq=args.val_freq),
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=2,
                                                pin_memory=True,
                                                drop_last=True)
    train_loader3 = torch.utils.data.DataLoader(dataset=EEV_Dataset(
        csv_path=args.train_csv,
        vidmap_path=args.train_vidmap,
        image_feat_path=args.image_features,
        audio_feat_path=args.audio_features,
        mode='train',
        lpfilter=args.lp_filter,
        train_freq=args.train_freq,
        val_freq=args.val_freq),
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=2,
                                                pin_memory=True,
                                                drop_last=True)

    val_loader = torch.utils.data.DataLoader(dataset=EEV_Dataset(
        csv_path=args.val_csv,
        vidmap_path=args.val_vidmap,
        image_feat_path=args.image_features,
        audio_feat_path=args.audio_features,
        mode='val',
        train_freq=args.train_freq,
        val_freq=args.val_freq),
                                             batch_size=None,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=False)

    accuracy = correlation
    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))

    tb_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        train(train_loader, train_loader2, train_loader3, model, model2,
              optimizer, epoch, log_training, tb_writer)
        # do lr scheduling after epoch
        if args.use_swa and epoch >= args.swa_start:
            print('swa stepping...')
            swa_model.update_parameters(model)
            swa_scheduler.step()
        elif args.use_cos_wr:
            print('cos warm restart (T0:{} Tm:{}) stepping...'.format(
                args.cos_wr_t0, args.cos_wr_t_mult))
            scheduler.step()
        elif args.use_cos:
            print('cos (Tmax:{}) stepping...'.format(args.cos_t_max))
            scheduler.step()

        if (epoch + 1) % args.eval_freq == 0 or (epoch + 1) == args.epochs:
            # validate
            if args.use_swa and epoch >= args.swa_start:
                # validate use swa model
                corr = validate(val_loader, swa_model, accuracy, epoch,
                                log_training, tb_writer)
            else:
                corr = validate(val_loader, model, accuracy, epoch,
                                log_training, tb_writer)
            is_best = corr > best_corr
            best_corr = max(corr, best_corr)
            tb_writer.add_scalar('acc/validate_corr_best', best_corr, epoch)
            output_best = 'Best corr: %.4f\n' % (best_corr)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_corr': best_corr,
                }, is_best)