示例#1
0
def main_train():
    """Train (or only validate) the model selected by ``args.model``.

    All configuration is read from the module-level ``args`` namespace, and
    the best validation correlation seen so far is kept in the module-level
    ``best_corr``.

    Steps:
      1. Build the model, wrap it in ``DataParallel`` and move it to CUDA.
      2. Create the optimizer (Adam, or SAM wrapping Adam), an optional
         learning-rate scheduler and the optional SWA model/scheduler.
      3. Optionally restore ``{epoch, state_dict, optimizer, best_corr}``
         from ``args.resume``.
      4. With ``args.val_only`` set, run one validation pass and return.
      5. Otherwise train up to ``args.epochs``; validate and save a
         checkpoint every ``args.eval_freq`` epochs and after the last
         epoch, but never during the first two epochs.

    Raises:
        ValueError: if ``args.model`` is not a supported model name.
    """
    global args, best_corr

    args.store_name = '{}'.format(args.model) + datetime.now().strftime(
        '_%m-%d_%H-%M-%S')
    args.start_epoch = 0

    if not args.val_only:
        check_rootfolders(args)

    if args.model == 'Baseline':
        if args.cls_indices:
            model = Baseline(args.img_feat_size,
                             args.au_feat_size,
                             num_classes=len(args.cls_indices))
        else:
            print('Feature size:', args.img_feat_size, args.au_feat_size)
            model = Baseline(args.img_feat_size, args.au_feat_size)
    elif args.model == 'TCFPN':
        model = TCFPN(layers=[48, 64, 96],
                      in_channels=128,
                      num_classes=15,
                      kernel_size=11)
    elif args.model == 'BaseAu':
        model = Baseline_Au(args.au_feat_size)
    elif args.model == 'BaseImg':
        model = Baseline_Img(args.img_feat_size)
    elif args.model == 'EmoBase':
        model = EmoBase()
    else:
        # Fail here with a clear message instead of an unbound-variable
        # error a few lines further down.
        raise ValueError('Unsupported model: {!r}'.format(args.model))

    model = torch.nn.DataParallel(model).cuda()

    if args.use_sam:
        # SAM wraps the base optimizer class; it is given only the learning
        # rate (no weight decay), as in the original configuration.
        optimizer = SAM(model.parameters(),
                        torch.optim.Adam,
                        lr=args.learning_rate)
    else:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.learning_rate,
                                     weight_decay=args.weight_decay)

    # At most one learning-rate scheduler is active; None means constant lr.
    scheduler = None
    if args.use_cos_wr:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=args.cos_wr_t0, T_mult=args.cos_wr_t_mult)
    elif args.use_cos:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.cos_t_max)
    elif args.use_multistep:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, args.step_milestones, args.step_decay)

    # Stochastic weight averaging (optional).
    swa_model = None
    swa_scheduler = None
    if args.use_swa:
        swa_model = torch.optim.swa_utils.AveragedModel(model)
        swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
                                                    swa_lr=args.learning_rate)

    # Checkpoint structure: {epoch, state_dict, optimizer, best_corr}
    if args.resume:
        if os.path.isfile(args.resume):
            print('Load checkpoint:', args.resume)
            ckpt = torch.load(args.resume)
            args.start_epoch = ckpt['epoch']
            best_corr = ckpt['best_corr']
            model.load_state_dict(ckpt['state_dict'])
            optimizer.load_state_dict(ckpt['optimizer'])
            print('Loaded ckpt at epoch:', args.start_epoch)
        else:
            # Keep going (as before) but make the situation visible.
            print('Checkpoint not found, starting from scratch:', args.resume)

    train_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(csv_path=args.train_csv,
                            vidmap_path=args.train_vidmap,
                            image_feat_path=args.image_features,
                            audio_feat_path=args.audio_features,
                            mode='train',
                            lpfilter=args.lp_filter,
                            train_freq=args.train_freq,
                            val_freq=args.val_freq,
                            cls_indices=args.cls_indices),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=False,
        drop_last=True)

    # batch_size=None disables automatic batching for validation.
    val_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(csv_path=args.val_csv,
                            vidmap_path=args.val_vidmap,
                            image_feat_path=args.image_features,
                            audio_feat_path=args.audio_features,
                            mode='val',
                            train_freq=args.train_freq,
                            val_freq=args.val_freq,
                            cls_indices=args.cls_indices,
                            repeat_sample=args.repeat_sample),
        batch_size=None,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=False)

    accuracy = correlation

    if args.val_only:
        print('Run validation ...')
        print('start epoch:', args.start_epoch, 'model:', args.resume)
        validate(val_loader, model, accuracy, args.start_epoch, None, None)
        return

    log_dir = os.path.join(args.root_log, args.store_name)
    log_training = open(os.path.join(log_dir, 'log.csv'), 'w')
    tb_writer = None
    try:
        with open(os.path.join(log_dir, 'args.txt'), 'w') as f:
            f.write(str(args))
        tb_writer = SummaryWriter(log_dir=log_dir)

        for epoch in range(args.start_epoch, args.epochs):
            train(train_loader, model, optimizer, epoch, log_training,
                  tb_writer)

            # Learning-rate scheduling happens once per epoch, after training.
            swa_active = args.use_swa and epoch >= args.swa_start
            if swa_active:
                print('swa stepping...')
                swa_model.update_parameters(model)
                swa_scheduler.step()
            elif scheduler is not None:
                scheduler.step()

            finished = epoch + 1
            if finished <= 2:
                # Validation is skipped for the first two epochs.
                continue
            if finished % args.eval_freq != 0 and finished != args.epochs:
                continue

            # Once SWA is active the averaged weights are evaluated.
            eval_model = swa_model if swa_active else model
            corr = validate(val_loader, eval_model, accuracy, epoch,
                            log_training, tb_writer)
            is_best = corr > best_corr
            best_corr = max(corr, best_corr)
            tb_writer.add_scalar('acc/validate_corr_best', best_corr, epoch)
            output_best = 'Best corr: %.4f\n' % (best_corr)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            # The checkpoint always stores the raw (non-averaged) model.
            save_checkpoint(
                {
                    'epoch': finished,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_corr': best_corr,
                }, is_best)
    finally:
        # Release the log file and the TensorBoard writer even when training
        # is interrupted by an exception.
        if tb_writer is not None:
            tb_writer.close()
        log_training.close()
示例#2
0
def main_train(config, checkpoint_dir=None):
    """Run one training trial driven by a hyper-parameter ``config``.

    The function saves checkpoints through ``tune.checkpoint_dir`` and
    reports metrics through ``tune.report`` after every epoch, so it looks
    like a Ray Tune function trainable (NOTE(review): confirm against the
    caller). Everything not in ``config`` is read from the module-level
    ``args``; the module-level ``best_corr`` is reset to 0.0 at the start of
    each call.

    Args:
        config: mapping with the keys ``'optimizer'`` (``'adam'`` or
            ``'adamw'``) and ``'batch_size'``, and normally ``'lr'``. When
            ``'lr'`` is absent, ``args.learning_rate`` is used.
        checkpoint_dir: optional directory containing a file named
            ``checkpoint`` holding ``(model_state, optimizer_state)`` to
            restore before training.

    Raises:
        ValueError: if ``args.model`` or ``config['optimizer']`` is not a
            supported name.
    """
    global args, best_corr
    best_corr = 0.0

    args.store_name = '{}'.format(args.model) + datetime.now().strftime(
        '_%m-%d_%H-%M-%S')
    args.start_epoch = 0

    if args.model == 'Baseline':
        model = Baseline()
    elif args.model == 'TCFPN':
        model = TCFPN(layers=[48, 64, 96],
                      in_channels=(2048 + 128),
                      num_classes=15,
                      kernel_size=11)
    else:
        raise ValueError('Unsupported model: {!r}'.format(args.model))

    model = torch.nn.DataParallel(model).cuda()

    # Use the trial's learning rate for every optimizer. Previously the
    # 'adam' branch silently ignored config['lr'] while AdamW, SAM and SWALR
    # used it.
    lr = config.get('lr', args.learning_rate)

    if args.use_sam:
        # SAM replaces whichever base optimizer was requested in config.
        optimizer = SAM(model.parameters(), torch.optim.Adam, lr=lr)
    elif config['optimizer'] == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif config['optimizer'] == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    else:
        raise ValueError('Unsupported optimizer: {!r}'.format(
            config['optimizer']))

    # At most one learning-rate scheduler is active; None means constant lr.
    scheduler = None
    if args.use_cos_wr:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=args.cos_wr_t0, T_mult=args.cos_wr_t_mult)
    elif args.use_cos:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.cos_t_max)

    # Stochastic weight averaging (optional).
    swa_model = None
    swa_scheduler = None
    if args.use_swa:
        swa_model = torch.optim.swa_utils.AveragedModel(model)
        swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=lr)

    if checkpoint_dir:
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    train_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=args.train_csv,
            vidmap_path=args.train_vidmap,
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='train', lpfilter=args.lp_filter
        ),
        batch_size=config['batch_size'], shuffle=True,
        num_workers=args.workers, pin_memory=False,
        drop_last=True
    )

    # batch_size=None disables automatic batching for validation.
    val_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=args.val_csv,
            vidmap_path=args.val_vidmap,
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='val'
        ),
        batch_size=None, shuffle=False,
        num_workers=args.workers, pin_memory=False
    )

    accuracy = correlation

    for epoch in range(args.start_epoch, args.epochs):
        # No log file / TensorBoard writer is passed in this variant.
        train(train_loader, model, optimizer, epoch, None, None)

        # Learning-rate scheduling happens once per epoch, after training.
        swa_active = args.use_swa and epoch >= args.swa_start
        if swa_active:
            print('swa stepping...')
            swa_model.update_parameters(model)
            swa_scheduler.step()
        elif args.use_cos_wr:
            print('cos warm restart (T0:{} Tm:{}) stepping...'.format(args.cos_wr_t0, args.cos_wr_t_mult))
            scheduler.step()
        elif args.use_cos:
            print('cos (Tmax:{}) stepping...'.format(args.cos_t_max))
            scheduler.step()

        # Once SWA is active the averaged weights are evaluated.
        eval_model = swa_model if swa_active else model
        corr, loss = validate(val_loader, eval_model, accuracy, epoch, None, None)
        is_best = corr > best_corr
        best_corr = max(corr, best_corr)

        with tune.checkpoint_dir(epoch) as ckpt_dir:
            state = (model.state_dict(), optimizer.state_dict())
            # Always write the file the restore path above reads. Previously
            # a best epoch wrote only "checkpoint_best", so restoring from
            # that directory failed.
            torch.save(state, os.path.join(ckpt_dir, "checkpoint"))
            if is_best:
                torch.save(state, os.path.join(ckpt_dir, "checkpoint_best"))
        tune.report(loss=loss, accuracy=corr, best_corr=best_corr)
示例#3
0
def main_train():
    """Train two identically-configured models jointly with one optimizer.

    Configuration is read from the module-level ``args``; the best
    validation correlation is tracked in the module-level ``best_corr``.
    Two models of the type named by ``args.model`` are built and optimised
    by one SGD optimizer (Nesterov momentum) with a parameter group per
    model. ``train`` receives three independently shuffled training loaders
    plus both models; validation and ``is_best`` are computed on the first
    model only.

    Checkpoints contain ``{epoch, state_dict, state_dict2, optimizer,
    best_corr}``. ``state_dict2`` holds the second model's weights and is
    optional when resuming, so checkpoints written without it still load.

    Raises:
        ValueError: if ``args.model`` is not a supported model name.
    """
    global args, best_corr

    args.store_name = '{}'.format(args.model)
    args.store_name = 'zzd' + args.store_name + datetime.now().strftime(
        '_%m-%d_%H-%M-%S')
    args.start_epoch = 0

    check_rootfolders(args)

    if args.model == 'Baseline':
        build_model = Baseline
    elif args.model == 'TCFPN':

        def build_model():
            return TCFPN(layers=[48, 64, 96],
                         in_channels=(2048 + 128),
                         num_classes=15,
                         kernel_size=11)
    else:
        raise ValueError('Unsupported model: {!r}'.format(args.model))

    # Both models are needed below. Previously model2 was only created for
    # 'Baseline', so 'TCFPN' crashed with an unbound variable.
    model = torch.nn.DataParallel(build_model()).cuda()
    model2 = torch.nn.DataParallel(build_model()).cuda()

    # NOTE(review): this overrides whatever learning rate was configured.
    args.learning_rate = 0.02
    print('init: %f' % args.learning_rate)
    if args.use_sam:
        # NOTE(review): SAM is only given the first model's parameters, so
        # model2 would not be optimised in this mode - confirm intent.
        optimizer = SAM(model.parameters(),
                        torch.optim.Adam,
                        lr=args.learning_rate)
    else:
        optimizer = torch.optim.SGD(
            [
                {
                    'params': model.parameters(),
                    'lr': args.learning_rate
                },
                {
                    'params': model2.parameters(),
                    'lr': args.learning_rate
                },
            ],
            weight_decay=1e-4,
            momentum=0.9,
            nesterov=True)

    # The scheduler is always plain cosine annealing; it is only stepped
    # when args.use_cos_wr or args.use_cos is set (see the epoch loop).
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.cos_t_max)

    # Stochastic weight averaging (optional, first model only).
    swa_model = None
    swa_scheduler = None
    if args.use_swa:
        swa_model = torch.optim.swa_utils.AveragedModel(model)
        swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
                                                    swa_lr=args.learning_rate)

    # Checkpoint structure:
    # {epoch, state_dict, state_dict2, optimizer, best_corr}
    if args.resume:
        if os.path.isfile(args.resume):
            print('Load checkpoint:', args.resume)
            ckpt = torch.load(args.resume)
            args.start_epoch = ckpt['epoch']
            best_corr = ckpt['best_corr']
            model.load_state_dict(ckpt['state_dict'])
            # Older checkpoints do not carry the second model's weights.
            if 'state_dict2' in ckpt:
                model2.load_state_dict(ckpt['state_dict2'])
            optimizer.load_state_dict(ckpt['optimizer'])
            print('Loaded ckpt at epoch:', args.start_epoch)
        else:
            # Keep going (as before) but make the situation visible.
            print('Checkpoint not found, starting from scratch:', args.resume)

    def make_train_loader():
        """Return a new shuffled training loader over its own dataset."""
        return torch.utils.data.DataLoader(
            dataset=EEV_Dataset(csv_path=args.train_csv,
                                vidmap_path=args.train_vidmap,
                                image_feat_path=args.image_features,
                                audio_feat_path=args.audio_features,
                                mode='train',
                                lpfilter=args.lp_filter,
                                train_freq=args.train_freq,
                                val_freq=args.val_freq),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=True,
            drop_last=True)

    train_loader = make_train_loader()
    train_loader2 = make_train_loader()
    train_loader3 = make_train_loader()

    # batch_size=None disables automatic batching for validation.
    val_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(csv_path=args.val_csv,
                            vidmap_path=args.val_vidmap,
                            image_feat_path=args.image_features,
                            audio_feat_path=args.audio_features,
                            mode='val',
                            train_freq=args.train_freq,
                            val_freq=args.val_freq),
        batch_size=None,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=False)

    accuracy = correlation

    log_dir = os.path.join(args.root_log, args.store_name)
    log_training = open(os.path.join(log_dir, 'log.csv'), 'w')
    tb_writer = None
    try:
        with open(os.path.join(log_dir, 'args.txt'), 'w') as f:
            f.write(str(args))
        tb_writer = SummaryWriter(log_dir=log_dir)

        for epoch in range(args.start_epoch, args.epochs):
            train(train_loader, train_loader2, train_loader3, model, model2,
                  optimizer, epoch, log_training, tb_writer)

            # Learning-rate scheduling happens once per epoch, after training.
            swa_active = args.use_swa and epoch >= args.swa_start
            if swa_active:
                print('swa stepping...')
                swa_model.update_parameters(model)
                swa_scheduler.step()
            elif args.use_cos_wr:
                # NOTE(review): the message mentions warm restarts, but the
                # scheduler created above is CosineAnnealingLR.
                print('cos warm restart (T0:{} Tm:{}) stepping...'.format(
                    args.cos_wr_t0, args.cos_wr_t_mult))
                scheduler.step()
            elif args.use_cos:
                print('cos (Tmax:{}) stepping...'.format(args.cos_t_max))
                scheduler.step()

            finished = epoch + 1
            if finished % args.eval_freq != 0 and finished != args.epochs:
                continue

            # Once SWA is active the averaged weights are evaluated.
            eval_model = swa_model if swa_active else model
            corr = validate(val_loader, eval_model, accuracy, epoch,
                            log_training, tb_writer)
            is_best = corr > best_corr
            best_corr = max(corr, best_corr)
            tb_writer.add_scalar('acc/validate_corr_best', best_corr, epoch)
            output_best = 'Best corr: %.4f\n' % (best_corr)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            # The optimizer state covers both models, so both sets of
            # weights are stored to keep a resumed run consistent.
            save_checkpoint(
                {
                    'epoch': finished,
                    'state_dict': model.state_dict(),
                    'state_dict2': model2.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_corr': best_corr,
                }, is_best)
    finally:
        # Release the log file and the TensorBoard writer even when training
        # is interrupted by an exception.
        if tb_writer is not None:
            tb_writer.close()
        log_training.close()