Example #1
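# Shared imports assumed by every snippet below (the original files are shown
# without their import headers, so this is a best-effort reconstruction).
# Project-specific names -- resnet18, MUSIC_Dataset, MUSIC_AV_Classify,
# Attention_Net, Location_Net_stage_one/_two, ContrastiveLoss, and the
# *_model_train / *_model_eva / *_model_test / extract_feature /
# feature_clustering / visualize_model helpers -- come from this repo's own
# modules, whose import paths are not shown here.
import argparse
import copy
import os
import pickle
import sys

import numpy as np
import torch
import torch.optim as optim
from sklearn.preprocessing import normalize
from torch.utils.data import DataLoader
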
def main():
    parser = argparse.ArgumentParser(description='AID_PRETRAIN')
    parser.add_argument('--data_list_dir',
                        type=str,
                        default='./data/data_indicator/music/solo')
    parser.add_argument('--data_dir', type=str, default='/MUSIC/solo/')
    parser.add_argument('--mode',
                        type=str,
                        default='train',
                        help='train/val/test')
    parser.add_argument('--json_file',
                        type=str,
                        default='./data/MUSIC_label/MUSIC_solo_videos.json')

    parser.add_argument('--use_class_task',
                        type=int,
                        default=1,
                        help='whether to use class task')
    parser.add_argument(
        '--init_num',
        type=int,
        default=8,
        help='epoch number for initializing the location model')
    parser.add_argument('--use_pretrain',
                        type=int,
                        default=1,
                        help='whether to init from ckpt')
    parser.add_argument('--ckpt_file',
                        type=str,
                        default='location_cluster_net_003_0.882_avg_whole.pth',
                        help='pretrained model name')
    parser.add_argument('--enable_img_augmentation',
                        type=int,
                        default=1,
                        help='whether to augment input image')
    parser.add_argument('--enable_audio_augmentation',
                        type=int,
                        default=1,
                        help='whether to augment input audio')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='training batch size')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=1e-4,
                        help='learning rate')
    parser.add_argument('--epoch',
                        type=int,
                        default=2000,
                        help='number of training epochs')
    parser.add_argument('--class_iter',
                        type=int,
                        default=3,
                        help='training iteration for classification model')
    parser.add_argument('--gpu_ids',
                        type=str,
                        default='[0,1,2,3]',
                        help="gpu ids to use, e.g. '[0,4]'")
    parser.add_argument('--num_threads',
                        type=int,
                        default=12,
                        help='number of threads')
    parser.add_argument('--seed', type=int, default=10)
    args = parser.parse_args()

    train_dataset = MUSIC_Dataset(args)
    args_test = copy.deepcopy(args)  # deep copy so mutating mode below does not also change args
    args_test.mode = 'test'
    val_dataset = MUSIC_Dataset(args_test)

    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_threads)
    val_dataloader = DataLoader(dataset=val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.num_threads)

    # net setup
    visual_backbone = resnet18(modal='vision', pretrained=True)
    audio_backbone = resnet18(modal='audio')
    av_model = Location_Net_stage_two(visual_net=visual_backbone,
                                      audio_net=audio_backbone)
    if args.use_pretrain:
        PATH = os.path.join('ckpt/stage_two_cosine2/', args.ckpt_file)
        state = torch.load(PATH)
        av_model.load_state_dict(state)
        print(PATH)
    av_model_cuda = av_model.cuda()

    # fixed model
    visual_backbone_fix = resnet18(modal='vision', pretrained=True)
    audio_backbone_fix = resnet18(modal='audio')
    av_model_fix = Location_Net_stage_one(visual_net=visual_backbone_fix,
                                          audio_net=audio_backbone_fix)
    if args.use_pretrain:
        PATH = os.path.join('ckpt/stage_one_cosine2/',
                            'location_cluster_net_iter_006_av_class.pth')
        state = torch.load(PATH)
        av_model_fix.load_state_dict(state)
        print('loaded weights')
    av_model_fix_cuda = av_model_fix.cuda()

    obj_rep = np.load(
        'obj_features2/obj_feature_softmax_avg_fc_epoch_6_av_entire.npy')

    eva_location_acc = visualize_model(av_model_cuda, av_model_fix_cuda,
                                       val_dataloader, obj_rep)
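
# Assumed entry point for each of these snippets (not shown in the originals):
if __name__ == '__main__':
    main()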
Example #2
def main():
    parser = argparse.ArgumentParser(description='AID_PRETRAIN')
    parser.add_argument(
        '--data_list_dir',
        type=str,
        default='/mnt/home/hudi/location_sound/data/data_indicator/music/solo')
    parser.add_argument('--data_dir',
                        type=str,
                        default='/mnt/scratch/hudi/MUSIC/solo/')
    parser.add_argument('--mode',
                        type=str,
                        default='train',
                        help='train/val/test')
    parser.add_argument(
        '--json_file',
        type=str,
        default='/mnt/home/hudi/location_sound/data/MUSIC_label/MUSIC_solo_videos.json')

    parser.add_argument('--use_class_task',
                        type=int,
                        default=1,
                        help='whether to use class task')
    parser.add_argument(
        '--init_num',
        type=int,
        default=0,
        help='epoch number for initializing the location model')
    parser.add_argument('--use_pretrain',
                        type=int,
                        default=0,
                        help='whether to init from ckpt')
    parser.add_argument('--ckpt_file',
                        type=str,
                        default='location_cluster_net_softmax_009_0.709.pth',
                        help='pretrained model name')
    parser.add_argument('--enable_img_augmentation',
                        type=int,
                        default=1,
                        help='whether to augment input image')
    parser.add_argument('--enable_audio_augmentation',
                        type=int,
                        default=1,
                        help='whether to augment input audio')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='training batch size')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=1e-4,
                        help='learning rate')
    parser.add_argument('--epoch',
                        type=int,
                        default=2000,
                        help='number of training epochs')
    parser.add_argument('--class_iter',
                        type=int,
                        default=30,
                        help='training iteration for classification model')
    parser.add_argument('--gpu_ids',
                        type=str,
                        default='[0,1,2,3]',
                        help="gpu ids to use, e.g. '[0,4]'")
    parser.add_argument('--num_threads',
                        type=int,
                        default=12,
                        help='number of threads')
    parser.add_argument('--seed', type=int, default=10)
    args = parser.parse_args()

    train_list_file = os.path.join(args.data_list_dir, 'solo_pairs_train.txt')
    val_list_file = os.path.join(args.data_list_dir, 'solo_pairs_val.txt')
    test_list_file = os.path.join(args.data_list_dir, 'solo_pairs_test.txt')

    train_dataset = MUSIC_Dataset(args.data_dir, train_list_file, args)
    val_dataset = MUSIC_Dataset(args.data_dir, val_list_file, args)
    test_dataset = MUSIC_Dataset(args.data_dir, test_list_file, args)

    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_threads)
    val_dataloader = DataLoader(dataset=val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.num_threads)
    test_dataloader = DataLoader(dataset=test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_threads)

    # net setup
    visual_backbone = resnet18(modal='vision', pretrained=True)
    audio_backbone = resnet18(modal='audio')
    av_model = Attention_Net(visual_net=visual_backbone,
                             audio_net=audio_backbone)

    if args.use_pretrain:
        PATH = os.path.join('ckpt/stage_one/', args.ckpt_file)
        state = torch.load(PATH)
        av_model.load_state_dict(state, strict=False)
    av_model_cuda = av_model.cuda()

    loss_func = ContrastiveLoss()

    optimizer = optim.Adam(params=av_model_cuda.parameters(),
                           lr=args.learning_rate,
                           betas=(0.9, 0.999),
                           weight_decay=0.0001)

    init_num = 0
    for e in range(args.epoch):
        print('Epoch is %03d' % e)

        train_location_acc = location_model_train(av_model_cuda,
                                                  train_dataloader, optimizer,
                                                  loss_func)
        eva_location_acc = location_model_eva(av_model_cuda, val_dataloader)

        print('train acc is %.3f, val acc is %.3f' %
              (train_location_acc, eva_location_acc))
        init_num += 1
        if e % 3 == 0:
            # checkpoint every third epoch
            PATH = 'ckpt/att/att_stage_one_%03d_%.3f_rand.pth' % (
                e, eva_location_acc)
            torch.save(av_model_cuda.state_dict(), PATH)
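
# ContrastiveLoss above is the repo's own module. As a hedged illustration
# only, a generic margin-based contrastive loss is sketched below under the
# assumption that the model produces a distance per audio-visual pair; the
# repo's actual formulation may differ.
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLossSketch(nn.Module):
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, dist, label):
        # label is 1 for matched audio-visual pairs, 0 for mismatched ones:
        # matched pairs are pulled together, mismatched pairs are pushed
        # apart until they clear the margin
        pos = label * dist.pow(2)
        neg = (1 - label) * F.relu(self.margin - dist).pow(2)
        return (pos + neg).mean()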
Example #3
def main():
    parser = argparse.ArgumentParser(description='AID_PRETRAIN')
    parser.add_argument('--data_list_dir',
                        type=str,
                        default='./data/data_indicator/music/solo')
    parser.add_argument('--data_dir',
                        type=str,
                        default='/home/ruiq/Music/solo')
    parser.add_argument('--mode',
                        type=str,
                        default='train',
                        help='train/val/test')
    parser.add_argument('--json_file',
                        type=str,
                        default='./data/MUSIC_label/MUSIC_solo_videos.json')

    parser.add_argument('--use_class_task',
                        type=int,
                        default=1,
                        help='whether to use class task')
    parser.add_argument(
        '--init_num',
        type=int,
        default=1,
        help='epoch number for initializing the location model')
    parser.add_argument('--use_pretrain',
                        type=int,
                        default=0,
                        help='whether to init from ckpt')
    parser.add_argument('--ckpt_file',
                        type=str,
                        default='location_cluster_net_norm_006_0.680.pth',
                        help='pretrained model name')
    parser.add_argument('--enable_img_augmentation',
                        type=int,
                        default=1,
                        help='whether to augment input image')
    parser.add_argument('--enable_audio_augmentation',
                        type=int,
                        default=1,
                        help='whether to augment input audio')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='training batch size')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=1e-4,
                        help='learning rate')
    parser.add_argument('--epoch',
                        type=int,
                        default=15,
                        help='number of training epochs')
    parser.add_argument('--class_iter',
                        type=int,
                        default=3,
                        help='training iteration for classification model')
    parser.add_argument('--gpu_ids',
                        type=str,
                        default='[0,1,2,3]',
                        help="gpu ids to use, e.g. '[0,4]'")
    parser.add_argument('--num_threads',
                        type=int,
                        default=12,
                        help='number of threads')
    parser.add_argument('--seed', type=int, default=10)
    parser.add_argument('--cluster', type=int, default=11)
    parser.add_argument('--mask', type=float, default=0.05)
    args = parser.parse_args()

    train_list_file = os.path.join(args.data_list_dir, 'solo_training_1.txt')
    val_list_file = os.path.join(args.data_list_dir, 'solo_pairs_val.txt')
    test_list_file = os.path.join(args.data_list_dir, 'solo_pairs_test.txt')

    train_dataset = MUSIC_Dataset(args.data_dir, train_list_file, args)
    val_dataset = MUSIC_Dataset(args.data_dir, val_list_file, args)
    test_dataset = MUSIC_Dataset(args.data_dir, test_list_file, args)

    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_threads)
    val_dataloader = DataLoader(dataset=val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.num_threads)
    test_dataloader = DataLoader(dataset=test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_threads)

    # net setup
    visual_backbone = resnet18(modal='vision', pretrained=True)
    audio_backbone = resnet18(modal='audio')
    av_model = Location_Net_stage_one(visual_net=visual_backbone,
                                      audio_net=audio_backbone,
                                      cluster=args.cluster)

    av_model_cuda = av_model.cuda()

    if args.use_pretrain:
        PATH = args.ckpt_file
        state = torch.load(PATH)
        av_model_cuda.load_state_dict(state, strict=False)
        print('loaded weights')
    else:
        # when training from scratch, initialize the audio-visual
        # correspondence conv weight to a constant
        av_model_cuda.conv_av.weight.data.fill_(5.)

    if args.mode == 'test':
        location_model_test(av_model_cuda, test_dataloader)
        return

    loss_func_BCE_location = torch.nn.BCELoss(reduction='mean')
    # note: despite the name, the classification head uses cross-entropy
    loss_func_BCE_class = torch.nn.CrossEntropyLoss()

    # exclude the final classification parameters (the last four tensors)
    # from the localization optimizer
    params = list(av_model_cuda.parameters())
    optimizer_location = optim.Adam(params=params[:-4],
                                    lr=args.learning_rate,
                                    betas=(0.9, 0.999),
                                    weight_decay=0.0001)

    init_num = 0
    # obj_features, img_features, aud_features, labels, img_dirs, aud_dirs \
    #     = extract_feature(av_model_cuda, train_dataloader)
    # np.save('img_feature', img_features)
    # np.save('obj_feature', obj_features)
    # np.save('labels', labels)
    # return

    for e in range(args.epoch):

        ############################### location training #################################
        print('Epoch is %03d' % e)
        train_location_acc = location_model_train(av_model_cuda,
                                                  train_dataloader,
                                                  optimizer_location,
                                                  loss_func_BCE_location)
        eva_location_acc = location_model_eva(av_model_cuda, val_dataloader)

        print('train acc is %.3f, val acc is %.3f' %
              (train_location_acc, eva_location_acc))
        init_num += 1
        # checkpoint every epoch; create the ckpt directory if needed
        ckpt_dir = 'ckpt/stage_one_%.2f_%d' % (args.mask, args.cluster)
        os.makedirs(ckpt_dir, exist_ok=True)
        PATH = os.path.join(
            ckpt_dir,
            'location_cluster_net_%03d_%.3f_av_local.pth' % (e, eva_location_acc))
        torch.save(av_model_cuda.state_dict(), PATH)

        ############################### classification training #################################
        if init_num > args.init_num and args.use_class_task:

            obj_features, img_features, aud_features, labels, img_dirs, aud_dirs = extract_feature(
                av_model_cuda, train_dataloader, args.mask)
            val_obj_features, val_img_features, val_aud_features, val_labels, val_img_dirs, val_aud_dirs = extract_feature(
                av_model_cuda, val_dataloader, args.mask)

            obj_features_ = normalize(obj_features, norm='l2')
            aud_features_ = normalize(aud_features, norm='l2')
            av_features = np.concatenate((obj_features_, aud_features_),
                                         axis=1)

            val_obj_features_ = normalize(val_obj_features, norm='l2')
            val_aud_features_ = normalize(val_aud_features, norm='l2')
            val_av_features = np.concatenate(
                (val_obj_features_, val_aud_features_), axis=1)

            pseudo_label, nmi_score, val_pseudo_label = feature_clustering(
                obj_features, labels, val_obj_features, args.cluster)
            print('obj_NMI is %.3f' % nmi_score)

            obj_fea = []
            for i in range(args.cluster):
                cur_idx = pseudo_label == i
                cur_fea = obj_features[cur_idx]
                obj_fea.append(np.mean(cur_fea, 0))
            # create the feature directory if needed before saving
            obj_dir = 'obj_features_%.2f_%d' % (args.mask, args.cluster)
            os.makedirs(obj_dir, exist_ok=True)
            np.save(
                os.path.join(obj_dir,
                             'obj_feature_softmax_avg_fc_epoch_%d_av_entire.npy' % e),
                obj_fea)

            cluster_dict = {'pseudo_label': pseudo_label, 'gt_labels': labels}
            with open(os.path.join(obj_dir, 'cluster_%d.pkl' % e),
                      'wb') as cluster_ptr:
                pickle.dump(cluster_dict, cluster_ptr)

            class_dataset = MUSIC_AV_Classify(img_dirs, aud_dirs, pseudo_label,
                                              args)
            class_dataloader = DataLoader(dataset=class_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          num_workers=args.num_threads)
            class_dataset_val = MUSIC_AV_Classify(val_img_dirs, val_aud_dirs,
                                                  val_pseudo_label, args)
            class_dataloader_val = DataLoader(dataset=class_dataset_val,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_threads)

            class_model_train(av_model_cuda, class_dataloader,
                              optimizer_location, loss_func_BCE_class, args)
            class_model_val(av_model_cuda, class_dataloader_val)

            # checkpoint after each classification round (ckpt_dir created above)
            PATH = os.path.join(
                ckpt_dir, 'location_cluster_net_iter_%03d_av_class.pth' % e)
            torch.save(av_model_cuda.state_dict(), PATH)
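
# feature_clustering above is a repo helper. As a hedged sketch only (assumed
# behavior, not the repo's confirmed code), it plausibly k-means-clusters the
# object features into args.cluster groups, scores the assignment against the
# ground-truth labels with NMI, and assigns the validation features to the
# learned clusters:
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score

def feature_clustering_sketch(features, labels, val_features, n_clusters):
    kmeans = KMeans(n_clusters=n_clusters, n_init=10).fit(features)
    pseudo_label = kmeans.labels_
    nmi_score = normalized_mutual_info_score(labels, pseudo_label)
    val_pseudo_label = kmeans.predict(val_features)
    return pseudo_label, nmi_score, val_pseudo_label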
Example #4
def main():
    parser = argparse.ArgumentParser(description='AID_PRETRAIN')
    parser.add_argument(
        '--data_list_dir',
        type=str,
        default='/mnt/home/hudi/location_sound/data/data_indicator/music/solo')
    parser.add_argument('--data_dir',
                        type=str,
                        default='/mnt/scratch/hudi/MUSIC/solo/')
    parser.add_argument('--mode',
                        type=str,
                        default='test',
                        help='train/val/test')
    parser.add_argument(
        '--json_file',
        type=str,
        default='/mnt/home/hudi/location_sound/data/MUSIC_label/MUSIC_solo_videos.json')

    parser.add_argument('--use_class_task',
                        type=int,
                        default=1,
                        help='whether to use class task')
    parser.add_argument(
        '--init_num',
        type=int,
        default=0,
        help='epoch number for initializing the location model')
    parser.add_argument('--use_pretrain',
                        type=int,
                        default=1,
                        help='whether to init from ckpt')
    parser.add_argument('--ckpt_file',
                        type=str,
                        default='att_stage_one_024_0.812_rand.pth',
                        help='pretrained model name')
    parser.add_argument('--enable_img_augmentation',
                        type=int,
                        default=1,
                        help='whether to augment input image')
    parser.add_argument('--enable_audio_augmentation',
                        type=int,
                        default=1,
                        help='whether to augment input audio')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='training batch size')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=1e-4,
                        help='learning rate')
    parser.add_argument('--epoch',
                        type=int,
                        default=2000,
                        help='number of training epochs')
    parser.add_argument('--class_iter',
                        type=int,
                        default=3,
                        help='training iteration for classification model')
    parser.add_argument('--gpu_ids',
                        type=str,
                        default='[0,1,2,3]',
                        help="gpu ids to use, e.g. '[0,4]'")
    parser.add_argument('--num_threads',
                        type=int,
                        default=12,
                        help='number of threads')
    parser.add_argument('--seed', type=int, default=10)
    args = parser.parse_args()

    if args.init_num != 0 and args.use_pretrain:
        sys.exit('When initializing from a checkpoint, set --init_num to 0.')

    val_list_file = os.path.join(args.data_list_dir, 'solo_pairs_val.txt')
    test_list_file = os.path.join(args.data_list_dir, 'solo_pairs_test.txt')

    val_dataset = MUSIC_Dataset(args.data_dir, val_list_file, args)
    test_dataset = MUSIC_Dataset(args.data_dir, test_list_file, args)

    val_dataloader = DataLoader(dataset=val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.num_threads)
    test_dataloader = DataLoader(dataset=test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_threads)

    # net setup
    visual_backbone = resnet18(modal='vision')
    audio_backbone = resnet18(modal='audio')
    av_model = Attention_Net(visual_net=visual_backbone,
                             audio_net=audio_backbone)

    if args.use_pretrain:
        PATH = os.path.join('ckpt/att/', args.ckpt_file)
        state = torch.load(PATH)
        av_model.load_state_dict(state)
        print(PATH)
    av_model_cuda = av_model.cuda()

    location_model_eva(av_model_cuda, val_dataloader)
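
    # Note: the test split is loaded above but never evaluated in this
    # snippet; mirroring the validation call would presumably be
    # location_model_eva(av_model_cuda, test_dataloader)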
Example #5
def main():
    parser = argparse.ArgumentParser(description='AID_PRETRAIN')
    parser.add_argument('--data_list_dir',
                        type=str,
                        default='./data/data_indicator/music/solo')
    parser.add_argument('--data_dir',
                        type=str,
                        default='/home/ruiq/MUSIC/solo/')
    parser.add_argument('--mode',
                        type=str,
                        default='train',
                        help='train/val/test')
    parser.add_argument('--json_file',
                        type=str,
                        default='./data/MUSIC_label/MUSIC_solo_videos.json')
    parser.add_argument('--weight',
                        type=float,
                        default=0.5,
                        help='weight for location loss and category loss')
    parser.add_argument('--use_class_task',
                        type=int,
                        default=1,
                        help='whether to use class task')
    parser.add_argument(
        '--init_num',
        type=int,
        default=8,
        help='epoch number for initializing the location model')
    parser.add_argument('--use_pretrain',
                        type=int,
                        default=1,
                        help='whether to init from ckpt')
    parser.add_argument('--ckpt_file',
                        type=str,
                        default='location_cluster_net_iter_006_av_class.pth',
                        help='pretrained model name')
    parser.add_argument('--enable_img_augmentation',
                        type=int,
                        default=1,
                        help='whether to augment input image')
    parser.add_argument('--enable_audio_augmentation',
                        type=int,
                        default=1,
                        help='whether to augment input audio')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='training batch size')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=1e-4,
                        help='learning rate')
    parser.add_argument('--epoch',
                        type=int,
                        default=5,
                        help='number of training epochs')
    parser.add_argument('--class_iter',
                        type=int,
                        default=3,
                        help='training iteration for classification model')
    parser.add_argument('--gpu_ids',
                        type=str,
                        default='[0,1,2,3]',
                        help="gpu ids to use, e.g. '[0,4]'")
    parser.add_argument('--num_threads',
                        type=int,
                        default=12,
                        help='number of threads')
    parser.add_argument('--seed', type=int, default=10)
    parser.add_argument('--cluster', type=int, default=11)
    parser.add_argument('--mask', type=float, default=0.05)
    args = parser.parse_args()

    weight = args.weight

    train_dataset = MUSIC_Dataset(args)
    args_test = copy.deepcopy(args)  # deep copy so mutating mode below does not also change args
    args_test.mode = 'val'
    val_dataset = MUSIC_Dataset(args_test)

    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_threads)
    val_dataloader = DataLoader(dataset=val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.num_threads)

    # net setup
    visual_backbone = resnet18(modal='vision', pretrained=True)
    audio_backbone = resnet18(modal='audio')
    av_model = Location_Net_stage_two(visual_net=visual_backbone,
                                      audio_net=audio_backbone,
                                      cluster=args.cluster)
    if args.use_pretrain:
        PATH = args.ckpt_file
        state = torch.load(PATH)
        av_model.load_state_dict(state, strict=False)
        print(PATH)

    av_model_cuda = av_model.cuda()

    # fixed model
    visual_backbone_fix = resnet18(modal='vision', pretrained=True)
    audio_backbone_fix = resnet18(modal='audio')
    av_model_fix = Location_Net_stage_one(visual_net=visual_backbone_fix,
                                          audio_net=audio_backbone_fix,
                                          cluster=args.cluster)
    if args.use_pretrain:
        PATH = args.ckpt_file
        state = torch.load(PATH)
        av_model_fix.load_state_dict(state)
        print('loaded weights')
    av_model_fix_cuda = av_model_fix.cuda()

    obj_rep = np.load(
        'obj_features_%.2f_%d/obj_feature_softmax_avg_fc_epoch_10_av_entire.npy'
        % (args.mask, args.cluster))

    loss_func_BCE_location = torch.nn.BCELoss(reduction='mean')
    # note: despite the name, the category loss is KL divergence
    loss_func_BCE_category = torch.nn.KLDivLoss(reduction='mean')

    optimizer = optim.Adam(params=av_model_cuda.parameters(),
                           lr=args.learning_rate,
                           betas=(0.9, 0.999),
                           weight_decay=0.0001)

    init_num = 0
    for e in range(args.epoch):
        print('Epoch is %03d' % e)

        train_location_acc = location_model_train(
            av_model_cuda, av_model_fix_cuda, train_dataloader, optimizer,
            loss_func_BCE_location, loss_func_BCE_category, e, obj_rep, weight)

        eva_location_acc = location_model_eva(av_model_cuda, val_dataloader,
                                              obj_rep)

        print('train acc is %.3f eval acc is %.3f' %
              (train_location_acc, eva_location_acc))
        init_num += 1
        # checkpoint every epoch; create the ckpt directory if needed
        ckpt_dir = 'ckpt/stage_syn_%.2f_%d' % (args.mask, args.cluster)
        os.makedirs(ckpt_dir, exist_ok=True)
        PATH = os.path.join(
            ckpt_dir,
            'location_cluster_net_%03d_%.3f_avg_whole.pth' % (e, train_location_acc))
        torch.save(av_model_cuda.state_dict(), PATH)