示例#1
0
def main(config):
    """Iteratively prune a model and fine-tune each pruned version on RWF-2000.

    The base network is chosen from ``config.model`` (``c3d``, ``convlstm``,
    ``densenet``, ``densenet_lean`` or ``resnext``; any other value falls back
    to ``densenet_lean``). Train/validation loaders are built for the RWF-2000
    frames under ``/content/RWF_2000/frames/``.

    For every prune round ``p`` in ``range(1, config.num_prune)``:

    * the previous round's model is loaded from
      ``{config.output}/pth/prune_{p-1}.pth``; if that file is missing, the
      model currently in memory is used instead (for round 1 this is the
      freshly built base network),
    * ``prune_model`` is applied and the parameter count is printed before
      and after pruning,
    * a fresh SGD optimizer and ``ReduceLROnPlateau`` scheduler are created
      for the pruned model, which is then fine-tuned for 5 epochs,
    * the model is saved to ``{config.output}/pth/prune_{p}.pth`` whenever
      validation accuracy improves (or ties with a lower loss).

    Args:
        config: run configuration. Attributes read here include ``model``,
            ``dataset``, ``sample_size``, ``stride``, ``sample_duration``,
            ``num_cv``, ``scales``, ``norm_value``, ``downsample``,
            ``train_batch``, ``val_batch``, ``output``, ``learning_rate``,
            ``momentum``, ``weight_decay``, ``factor``, ``min_lr``,
            ``num_prune`` and ``device``.
    """
    if config.model == 'c3d':
        model, _ = C3D(config)
    elif config.model == 'convlstm':
        model, _ = ConvLSTM(config)
    elif config.model == 'densenet':
        model, _ = densenet(config)
    elif config.model == 'densenet_lean':
        model, _ = densenet_lean(config)
    elif config.model == 'resnext':
        model, _ = resnext(config)
    else:
        model, _ = densenet_lean(config)

    dataset = config.dataset
    sample_size = config.sample_size
    stride = config.stride
    sample_duration = config.sample_duration

    cv = config.num_cv

    # train set: random flip + multi-scale random crop, then fixed
    # per-channel mean subtraction with unit std.
    crop_method = MultiScaleRandomCrop(config.scales, config.sample_size[0])
    norm = Normalize([114.7748, 107.7354, 99.475], [1, 1, 1])
    spatial_transform = Compose([
        RandomHorizontalFlip(), crop_method,
        ToTensor(config.norm_value), norm
    ])
    temporal_transform = TemporalRandomCrop(config.sample_duration,
                                            config.downsample)
    target_transform = Label()

    train_batch = config.train_batch
    train_data = RWF2000('/content/RWF_2000/frames/',
                         g_path + '/RWF-2000.json', 'training',
                         spatial_transform, temporal_transform,
                         target_transform, dataset)
    train_loader = DataLoader(train_data,
                              batch_size=train_batch,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)

    # val set: deterministic center crops with ImageNet normalisation.
    crop_method = GroupScaleCenterCrop(size=sample_size)
    norm = Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    spatial_transform = Compose([crop_method, ToTensor(), norm])
    temporal_transform = CenterCrop(size=sample_duration, stride=stride)
    target_transform = Label()

    val_batch = config.val_batch

    val_data = RWF2000('/content/RWF_2000/frames/', g_path + '/RWF-2000.json',
                       'validation', spatial_transform, temporal_transform,
                       target_transform, dataset)
    val_loader = DataLoader(val_data,
                            batch_size=val_batch,
                            shuffle=False,
                            num_workers=4,
                            pin_memory=True)

    # makedirs also creates config.output itself if it does not exist yet.
    for folder in ('{}/pth'.format(config.output),
                   '{}/log'.format(config.output)):
        os.makedirs(folder, exist_ok=True)

    batch_log = Log(
        '{}/log/{}_fps{}_{}_batch{}.log'.format(
            config.output,
            config.model,
            sample_duration,
            dataset,
            cv,
        ), ['epoch', 'batch', 'iter', 'loss', 'acc', 'lr'])
    epoch_log = Log(
        '{}/log/{}_fps{}_{}_epoch{}.log'.format(config.output, config.model,
                                                sample_duration, dataset, cv),
        ['epoch', 'loss', 'acc', 'lr'])
    val_log = Log(
        '{}/log/{}_fps{}_{}_val{}.log'.format(config.output, config.model,
                                              sample_duration, dataset, cv),
        ['epoch', 'loss', 'acc'])

    criterion = nn.CrossEntropyLoss().to(device)

    learning_rate = config.learning_rate
    momentum = config.momentum
    weight_decay = config.weight_decay

    for p in range(1, config.num_prune):
        prev_ckpt = '{}/pth/prune_{}.pth'.format(config.output, p - 1)
        if os.path.exists(prev_ckpt):
            model = torch.load(prev_ckpt)
        else:
            # Previously this crashed; continue from the in-memory model.
            print('{} not found; pruning the model currently in memory'.format(
                prev_ckpt))
        print(f"Prune {p}/{config.num_prune}")
        num_params = sum(param.numel() for param in model.parameters())
        print("Number of Parameters: %.1fM" % (num_params / 1e6))
        model = prune_model(model)
        num_params = sum(param.numel() for param in model.parameters())
        print("Number of Parameters: %.1fM" % (num_params / 1e6))
        model.to(config.device)

        # An optimizer only updates the tensors it was constructed with. The
        # model here comes from torch.load / prune_model, so an optimizer
        # built earlier would never touch it: rebuild both the optimizer and
        # the scheduler that wraps it for the current parameters.
        optimizer = torch.optim.SGD(params=model.parameters(),
                                    lr=learning_rate,
                                    momentum=momentum,
                                    weight_decay=weight_decay,
                                    dampening=0,
                                    nesterov=False)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, verbose=True, factor=config.factor,
            min_lr=config.min_lr)

        # Each prune round only competes against its own epochs.
        acc_baseline = 0
        loss_baseline = 1
        for i in range(5):
            train(i, train_loader, model, criterion, optimizer, device,
                  batch_log, epoch_log)
            val_loss, val_acc = val(i, val_loader, model, criterion, device,
                                    val_log)
            scheduler.step(val_loss)
            if val_acc > acc_baseline or (val_acc >= acc_baseline
                                          and val_loss < loss_baseline):
                torch.save(model,
                           '{}/pth/prune_{}.pth'.format(config.output, p))
                # Track the best epoch so later, worse epochs do not
                # overwrite the saved checkpoint.
                acc_baseline = val_acc
                loss_baseline = val_loss
def main(config):
    """Train a 2D-backbone classifier on the protest or hockey dataset.

    The model is ``VioNet_Resnet`` (``config.model == 'resnet50'``) or
    ``VioNet_Densenet2D`` (``'densenet2D'``). Per-batch, per-epoch, validation
    and combined train/val metrics are written as CSV files under the
    ``VioNet_log`` folder and as scalars to TensorBoard. A ``state_dict``
    checkpoint goes to ``VioNet_pth`` whenever validation accuracy improves
    (or ties with a lower loss).

    Args:
        config: run configuration (model, dataset, split, input mode,
            temporal transform, optimiser hyper-parameters, ...).

    Raises:
        ValueError: if ``config.model`` or ``config.dataset`` is not one of
            the supported values (previously this surfaced later as an
            ``UnboundLocalError``).
    """
    # load model
    if config.model == 'resnet50':
        model, params = VioNet_Resnet(config)
    elif config.model == 'densenet2D':
        model, params = VioNet_Densenet2D(config)
    else:
        raise ValueError(
            "unsupported model '{}' (expected 'resnet50' or 'densenet2D')".
            format(config.model))

    # dataset
    dataset = config.dataset
    sample_duration = config.sample_duration

    # cross validation phase
    cv = config.num_cv
    input_mode = config.input_mode
    temp_transform = config.temporal_transform

    if dataset == "protest":
        train_loader, val_loader = load_protest_dataset(config)
    elif dataset == "hockey":
        train_loader, val_loader = load_hockey_dataset(config)
    else:
        raise ValueError(
            "unsupported dataset '{}' (expected 'protest' or 'hockey')".format(
                dataset))

    log_path = getFolder('VioNet_log')
    chk_path = getFolder('VioNet_pth')
    tsb_path = getFolder('VioNet_tensorboard_log')

    # Every log / TensorBoard name below is keyed by the same run description.
    run_info = (config.model, sample_duration, dataset, cv, input_mode,
                temp_transform, config.additional_info)

    log_tsb_dir = tsb_path + '/{}_fps{}_{}_split{}_input({})_tempTransform({})_Info({})'.format(
        *run_info)
    # tsb_path is created before log_tsb_dir, which lives inside it.
    for pth in [log_path, chk_path, tsb_path, log_tsb_dir]:
        if not os.path.exists(pth):
            os.mkdir(pth)

    print('tensorboard dir:', log_tsb_dir)
    writer = SummaryWriter(log_tsb_dir)

    # log
    batch_log = Log(
        log_path + '/{}_fps{}_{}_batch{}_input({})_tempTransform({})_Info({}).log.csv'.format(
            *run_info),
        ['epoch', 'batch', 'iter', 'loss', 'acc', 'lr'])
    epoch_log = Log(
        log_path + '/{}_fps{}_{}_epoch{}_input({})_tempTransform({})_Info({}).log.csv'.format(
            *run_info),
        ['epoch', 'loss', 'acc', 'lr'])
    val_log = Log(
        log_path + '/{}_fps{}_{}_val{}_input({})_tempTransform({})_Info({}).log.csv'.format(
            *run_info),
        ['epoch', 'loss', 'acc'])
    train_val_log = Log(
        log_path + '/{}_fps{}_{}_split{}_input({})_tempTransform({})_Info({}).LOG.csv'.format(
            *run_info),
        ['epoch', 'train_loss', 'train_acc', 'lr', 'val_loss', 'val_acc'])

    # prepare
    criterion = nn.CrossEntropyLoss().to(device)

    learning_rate = config.learning_rate
    momentum = config.momentum
    weight_decay = config.weight_decay

    optimizer = torch.optim.SGD(params=params,
                                lr=learning_rate,
                                momentum=momentum,
                                weight_decay=weight_decay)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, verbose=True, factor=config.factor, min_lr=config.min_lr)

    acc_baseline = config.acc_baseline
    loss_baseline = 1

    for i in range(config.num_epoch):
        train_loss, train_acc, lr = train(i, train_loader, model, criterion,
                                          optimizer, device, batch_log,
                                          epoch_log)
        val_loss, val_acc = val(i, val_loader, model, criterion, device,
                                val_log)
        epoch = i + 1
        train_val_log.log({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'lr': lr,
            'val_loss': val_loss,
            'val_acc': val_acc
        })
        writer.add_scalar('training loss', train_loss, epoch)
        writer.add_scalar('training accuracy', train_acc, epoch)
        writer.add_scalar('validation loss', val_loss, epoch)
        writer.add_scalar('validation accuracy', val_acc, epoch)

        scheduler.step(val_loss)
        if val_acc > acc_baseline or (val_acc >= acc_baseline
                                      and val_loss < loss_baseline):
            torch.save(
                model.state_dict(),
                chk_path + '/{}_fps{}_{}{}_{}_{:.4f}_{:.6f}.pth'.format(
                    config.model, sample_duration, dataset, cv, i, val_acc,
                    val_loss))
            acc_baseline = val_acc
            loss_baseline = val_loss
def main(config: Config, root, annotation_path):
    """Train a model on HMDB51 frames, logging to CSV and TensorBoard.

    The model is ``VioNet_Resnet`` (``'resnet50'``), ``VioNet_Densenet2D``
    (``'densenet2D'``) or ``VioNet_ResnetXT`` (``'resnetXT'``). Data loaders
    come from ``laod_HMDB51_frames_dataset(config, root, annotation_path)``.
    A ``state_dict`` checkpoint is written to ``VioNet_pth`` whenever
    validation accuracy improves (or ties with a lower loss).

    Args:
        config: run configuration.
        root: frames root forwarded to ``laod_HMDB51_frames_dataset``.
        annotation_path: annotation location forwarded to
            ``laod_HMDB51_frames_dataset``.

    Raises:
        ValueError: if ``config.model`` is not supported (previously this
            surfaced later as an ``UnboundLocalError``).
    """
    if config.model == 'resnet50':
        model, params = VioNet_Resnet(config)
    elif config.model == 'densenet2D':
        model, params = VioNet_Densenet2D(config)
    elif config.model == 'resnetXT':
        model, params = VioNet_ResnetXT(config)
    else:
        raise ValueError(
            "unsupported model '{}' (expected 'resnet50', 'densenet2D' or "
            "'resnetXT')".format(config.model))
    log_path = getFolder('VioNet_log')
    chk_path = getFolder('VioNet_pth')
    tsb_path = getFolder('VioNet_tensorboard_log')

    # Every log / TensorBoard name below is keyed by the same run description.
    run_info = (config.model, config.sample_duration, config.dataset,
                config.num_cv, config.input_mode, config.additional_info)

    log_tsb_dir = tsb_path + '/{}_fps{}_{}_split{}_input({})_Info({})'.format(
        *run_info)
    # tsb_path is created before log_tsb_dir, which lives inside it.
    for pth in [log_path, chk_path, tsb_path, log_tsb_dir]:
        if not os.path.exists(pth):
            os.mkdir(pth)

    print('tensorboard dir:', log_tsb_dir)
    writer = SummaryWriter(log_tsb_dir)

    # log
    batch_log = Log(
        log_path + '/{}_fps{}_{}_batch{}_input({})_Info({}).log.csv'.format(
            *run_info), ['epoch', 'batch', 'iter', 'loss', 'acc', 'lr'])
    epoch_log = Log(
        log_path + '/{}_fps{}_{}_epoch{}_input({})_Info({}).log.csv'.format(
            *run_info), ['epoch', 'loss', 'acc', 'lr'])
    val_log = Log(
        log_path + '/{}_fps{}_{}_val{}_input({})_Info({}).log.csv'.format(
            *run_info), ['epoch', 'loss', 'acc'])
    train_val_log = Log(
        log_path + '/{}_fps{}_{}_split{}_input({})_Info({}).LOG.csv'.format(
            *run_info),
        ['epoch', 'train_loss', 'train_acc', 'lr', 'val_loss', 'val_acc'])

    criterion = nn.CrossEntropyLoss().to(device)

    learning_rate = config.learning_rate
    momentum = config.momentum
    weight_decay = config.weight_decay

    optimizer = torch.optim.SGD(params=params,
                                lr=learning_rate,
                                momentum=momentum,
                                weight_decay=weight_decay)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, verbose=True, factor=config.factor, min_lr=config.min_lr)

    acc_baseline = config.acc_baseline
    loss_baseline = 1

    train_loader, val_loader = laod_HMDB51_frames_dataset(
        config, root, annotation_path)

    for i in range(config.num_epoch):
        train_loss, train_acc, lr = train(i, train_loader, model, criterion,
                                          optimizer, device, batch_log,
                                          epoch_log)
        val_loss, val_acc = val(i, val_loader, model, criterion, device,
                                val_log)
        epoch = i + 1
        train_val_log.log({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'lr': lr,
            'val_loss': val_loss,
            'val_acc': val_acc
        })
        writer.add_scalar('training loss', train_loss, epoch)
        writer.add_scalar('training accuracy', train_acc, epoch)
        writer.add_scalar('validation loss', val_loss, epoch)
        writer.add_scalar('validation accuracy', val_acc, epoch)

        scheduler.step(val_loss)
        if val_acc > acc_baseline or (val_acc >= acc_baseline
                                      and val_loss < loss_baseline):
            torch.save(
                model.state_dict(), chk_path +
                '/{}_fps{}_{}{}_{}_{:.4f}_{:.6f}_config({}).pth'.format(
                    config.model, config.sample_duration, config.dataset,
                    config.num_cv, i, val_acc, val_loss,
                    config.additional_info))
            acc_baseline = val_acc
            loss_baseline = val_loss
示例#4
0
文件: main.py 项目: poem2018/AVSS2019
def main(config):
    """Train a 3D violence-detection network on one VioDB split.

    Architecture: ``VioNet_C3D`` for ``'c3d'``, ``VioNet_ConvLSTM`` for
    ``'convlstm'``, ``VioNet_densenet`` for ``'densenet'``; ``'densenet_lean'``
    and any other name use ``VioNet_densenet_lean``.

    Frames are read from ``../VioDB/{dataset}_jpg/`` and the split annotation
    from ``../VioDB/{dataset}_jpg{num_cv}.json``. Logs are written under
    ``./log``. A ``state_dict`` checkpoint is written under ``./pth`` whenever
    validation accuracy improves (or ties with a lower loss).
    """
    choice = config.model
    if choice == 'c3d':
        build_model = VioNet_C3D
    elif choice == 'convlstm':
        build_model = VioNet_ConvLSTM
    elif choice == 'densenet':
        build_model = VioNet_densenet
    else:
        # 'densenet_lean' and every unrecognised name share this builder.
        build_model = VioNet_densenet_lean
    model, params = build_model(config)

    dataset = config.dataset
    sample_size = config.sample_size
    stride = config.stride
    sample_duration = config.sample_duration
    cv = config.num_cv  # cross-validation split index

    frames_dir = '../VioDB/{}_jpg/'.format(dataset)
    split_file = '../VioDB/{}_jpg{}.json'.format(dataset, cv)

    # Training pipeline: random scale/center crop and horizontal flip.
    train_spatial = Compose([
        GroupRandomScaleCenterCrop(size=sample_size),
        GroupRandomHorizontalFlip(),
        ToTensor(),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    train_temporal = RandomCrop(size=sample_duration, stride=stride)
    train_set = VioDB(frames_dir, split_file, 'training', train_spatial,
                      train_temporal, Label())
    train_loader = DataLoader(train_set,
                              batch_size=config.train_batch,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)

    # Validation pipeline: deterministic center crops.
    val_spatial = Compose([
        GroupScaleCenterCrop(size=sample_size),
        ToTensor(),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    val_temporal = CenterCrop(size=sample_duration, stride=stride)
    val_set = VioDB(frames_dir, split_file, 'validation', val_spatial,
                    val_temporal, Label())
    val_loader = DataLoader(val_set,
                            batch_size=config.val_batch,
                            shuffle=False,
                            num_workers=4,
                            pin_memory=True)

    for folder in ('./pth', './log'):
        if not os.path.exists(folder):
            os.mkdir(folder)

    batch_log = Log(
        './log/{}_fps{}_{}_batch{}.log'.format(config.model, sample_duration,
                                               dataset, cv),
        ['epoch', 'batch', 'iter', 'loss', 'acc', 'lr'])
    epoch_log = Log(
        './log/{}_fps{}_{}_epoch{}.log'.format(config.model, sample_duration,
                                               dataset, cv),
        ['epoch', 'loss', 'acc', 'lr'])
    val_log = Log(
        './log/{}_fps{}_{}_val{}.log'.format(config.model, sample_duration,
                                             dataset, cv),
        ['epoch', 'loss', 'acc'])

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(params=params,
                                lr=config.learning_rate,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, verbose=True, factor=config.factor, min_lr=config.min_lr)

    best_acc = config.acc_baseline
    best_loss = 1

    for epoch_idx in range(config.num_epoch):
        train(epoch_idx, train_loader, model, criterion, optimizer, device,
              batch_log, epoch_log)
        val_loss, val_acc = val(epoch_idx, val_loader, model, criterion,
                                device, val_log)
        scheduler.step(val_loss)

        improved = val_acc > best_acc or (val_acc >= best_acc
                                          and val_loss < best_loss)
        if not improved:
            continue
        checkpoint = './pth/{}_fps{}_{}{}_{}_{:.4f}_{:.6f}.pth'.format(
            config.model, sample_duration, dataset, cv, epoch_idx, val_acc,
            val_loss)
        torch.save(model.state_dict(), checkpoint)
        best_acc, best_loss = val_acc, val_loss
示例#5
0
def _viodb_split_paths(home_path, dataset, cv):
    """Return ``(frames_root, annotation_json)`` for a VioDB-style dataset."""
    if dataset == RWF_DATASET:
        # RWF-2000 keeps its frames in its own folder and always uses the
        # first annotation file, regardless of ``cv``.
        return (os.path.join(home_path, RWF_DATASET.upper(), 'frames/'),
                os.path.join(home_path, VIO_DB_DATASETS,
                             "rwf-2000_jpg1.json"))
    return (os.path.join(home_path, VIO_DB_DATASETS, dataset + '_jpg'),
            os.path.join(home_path, VIO_DB_DATASETS,
                         '{}_jpg{}.json'.format(dataset, cv)))


def main(config, home_path):
    """Train a violence-detection model on a VioDB-style dataset.

    The model is built from ``config.model``; unknown names fall back to
    ``VioNet_densenet_lean``. Crop size and normalisation come from
    ``build_transforms_parameters``. Temporal sampling comes from
    ``build_temporal_transformation`` with ``config.train_temporal_transform``
    and ``config.val_temporal_transform``. CSV logs and TensorBoard scalars
    are written under ``home_path/PATH_LOG/<run>`` and
    ``home_path/PATH_TENSORBOARD/<run>``. A ``state_dict`` checkpoint goes to
    ``home_path/PATH_CHECKPOINT`` whenever validation accuracy improves (or
    ties with a lower loss).

    Args:
        config: run configuration (model, dataset, split, input type,
            temporal transforms, optimiser hyper-parameters, ...).
        home_path: root directory holding the datasets and output folders.
    """
    # load model
    if config.model == 'c3d':
        model, params = VioNet_C3D(config, home_path)
    elif config.model == 'convlstm':
        model, params = VioNet_ConvLSTM(config)
    elif config.model == 'densenet':
        model, params = VioNet_densenet(config, home_path)
    elif config.model == 'densenet_lean':
        model, params = VioNet_densenet_lean(config, home_path)
    elif config.model == 'resnet50':
        model, params = VioNet_Resnet(config, home_path)
    elif config.model == 'densenet2D':
        model, params = VioNet_Densenet2D(config)
    elif config.model == 'i3d':
        model, params = VioNet_I3D(config)
    elif config.model == 's3d':
        model, params = VioNet_S3D(config)
    else:
        model, params = VioNet_densenet_lean(config, home_path)

    dataset = config.dataset
    sample_duration = config.sample_duration
    cv = config.num_cv  # cross-validation split index
    input_mode = config.input_type

    sample_size, norm = build_transforms_parameters(model_type=config.model)

    frames_root, annotation_path = _viodb_split_paths(home_path, dataset, cv)
    tmp_annotation_path = os.path.join(g_path, config.temp_annotation_path)

    def make_loader(subset, spatial_transform, temporal_transform,
                    batch_size, shuffle, num_workers):
        """Build a VioDB subset ('training'/'validation') and wrap it."""
        data = VioDB(frames_root,
                     annotation_path,
                     subset,
                     spatial_transform,
                     temporal_transform,
                     Label(),
                     dataset,
                     tmp_annotation_path=tmp_annotation_path,
                     input_type=config.input_type)
        return DataLoader(data,
                          batch_size=batch_size,
                          shuffle=shuffle,
                          num_workers=num_workers,
                          pin_memory=True)

    # train set: random scale/center crop + horizontal flip
    train_temporal_transform = build_temporal_transformation(
        config, config.train_temporal_transform)
    train_spatial_transform = Compose([
        GroupRandomScaleCenterCrop(size=sample_size),
        GroupRandomHorizontalFlip(),
        ToTensor(), norm
    ])
    # NOTE(review): training uses num_workers=0 while validation uses 4 —
    # confirm this asymmetry is intentional.
    train_loader = make_loader('training', train_spatial_transform,
                               train_temporal_transform, config.train_batch,
                               True, 0)

    # val set: deterministic center crop
    val_temporal_transform = build_temporal_transformation(
        config, config.val_temporal_transform)
    val_spatial_transform = Compose(
        [GroupScaleCenterCrop(size=sample_size), ToTensor(), norm])
    val_loader = make_loader('validation', val_spatial_transform,
                             val_temporal_transform, config.val_batch, False,
                             4)

    template = '{}_fps{}_{}_split({})_input({})_TmpTransform({})_Info({})'.format(
        config.model, sample_duration, dataset, cv, input_mode,
        config.train_temporal_transform, config.additional_info)
    log_path = os.path.join(home_path, PATH_LOG, template)
    tsb_path = os.path.join(home_path, PATH_TENSORBOARD, template)
    chk_dir = os.path.join(home_path, PATH_CHECKPOINT)

    # makedirs creates missing parents (os.mkdir failed when PATH_LOG or
    # PATH_TENSORBOARD did not exist yet). It also creates the checkpoint
    # folder that torch.save below writes into.
    for pth in (log_path, tsb_path, chk_dir):
        os.makedirs(pth, exist_ok=True)

    print('tensorboard dir:', tsb_path)
    writer = SummaryWriter(tsb_path)

    # log
    batch_log = Log(log_path + '/batch_log.csv',
                    ['epoch', 'batch', 'iter', 'loss', 'acc', 'lr'])
    epoch_log = Log(log_path + '/epoch_log.csv',
                    ['epoch', 'loss', 'acc', 'lr'])
    val_log = Log(log_path + '/val_log.csv', ['epoch', 'loss', 'acc'])
    train_val_log = Log(
        log_path + '/train_val_LOG.csv',
        ['epoch', 'train_loss', 'train_acc', 'lr', 'val_loss', 'val_acc'])

    # prepare
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(params=params,
                                lr=config.learning_rate,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, verbose=True, factor=config.factor, min_lr=config.min_lr)

    acc_baseline = config.acc_baseline
    loss_baseline = 1

    for i in range(config.num_epoch):
        train_loss, train_acc, lr = train(i, train_loader, model, criterion,
                                          optimizer, device, batch_log,
                                          epoch_log)
        val_loss, val_acc = val(i, val_loader, model, criterion, device,
                                val_log)
        epoch = i + 1
        train_val_log.log({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'lr': lr,
            'val_loss': val_loss,
            'val_acc': val_acc
        })
        writer.add_scalar('training loss', train_loss, epoch)
        writer.add_scalar('training accuracy', train_acc, epoch)
        writer.add_scalar('validation loss', val_loss, epoch)
        writer.add_scalar('validation accuracy', val_acc, epoch)

        scheduler.step(val_loss)
        if val_acc > acc_baseline or (val_acc >= acc_baseline
                                      and val_loss < loss_baseline):
            torch.save(
                model.state_dict(),
                os.path.join(
                    chk_dir, '{}_fps{}_{}{}_{}_{:.4f}_{:.6f}.pth'.format(
                        config.model, sample_duration, dataset, cv, epoch,
                        val_acc, val_loss)))
            acc_baseline = val_acc
            loss_baseline = val_loss