Example #1
def main():
    parser = argparse.ArgumentParser(description='AID_PRETRAIN')
    parser.add_argument('--dataset_dir',
                        type=str,
                        default='F:\\download\\CVS_Dataset_New\\',
                        help='the path of the dataset')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='training batch size')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=1e-4,
                        help='learning rate')
    parser.add_argument('--epoch',
                        type=int,
                        default=2000,
                        help='training epoch')
    parser.add_argument('--gpu_ids',
                        type=str,
                        default='[0,1,2,3]',
                        help="GPU ids to use, e.g. '[0,4]'")
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--image_net_weights',
                        type=str,
                        default='visual_model_pretrain.pt',
                        help='image net weights')
    parser.add_argument('--audio_net_weights',
                        type=str,
                        default='checkpoint59.pt',
                        help='audio net weights')

    parser.add_argument('--data_dir',
                        type=str,
                        default='/mnt/scratch/hudi/soundscape/data/',
                        help='the path of the data directory')
    parser.add_argument('--num_threads',
                        type=int,
                        default=8,
                        help='number of threads')
    parser.add_argument('--data_name', type=str, default='CVS_data_ind.pkl')
    parser.add_argument('--seed', type=int, default=10)
    parser.add_argument('--audionet_pretrain', type=int, default=0)
    parser.add_argument('--videonet_pretrain', type=int, default=0)
    parser.add_argument('--kd_weight', type=float, default=0.1)

    args = parser.parse_args()

    print('kl_model...')
    print('baseline...')
    print('kd_weight ' + str(args.kd_weight))
    print('audionet_pretrain ' + str(args.audionet_pretrain))
    print('videonet_pretrain ' + str(args.videonet_pretrain))

    (train_sample, train_label, val_sample, val_label, test_sample,
     test_label) = data_construction(args.data_dir)

    #f = open(args.data_name, 'wb')
    #data = {'train_sample':train_sample, 'train_label':train_label, 'test_sample':test_sample, 'test_label':test_label}
    #pickle.dump(data, f)
    #f.close()

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    train_dataset = CVSDataset(args.data_dir,
                               train_sample,
                               train_label,
                               seed=args.seed,
                               event_label_name='event_label_bayes_59')
    val_dataset = CVSDataset(args.data_dir,
                             val_sample,
                             val_label,
                             seed=args.seed,
                             event_label_name='event_label_bayes_59')
    test_dataset = CVSDataset(args.data_dir,
                              test_sample,
                              test_label,
                              seed=args.seed,
                              event_label_name='event_label_bayes_59')

    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_threads)
    val_dataloader = DataLoader(dataset=val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.num_threads)
    test_dataloader = DataLoader(dataset=test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_threads)

    image_net = IMG_NET(num_classes=30)
    if args.videonet_pretrain:
        state = torch.load(args.image_net_weights)
        image_net.load_state_dict(state)

    audio_net = AUD_NET()
    if args.audionet_pretrain:
        state = torch.load(args.audio_net_weights)['model']
        audio_net.load_state_dict(state)

    # assemble the audio-visual fusion network from the two branches
    fusion_net = FusionNet_KD(image_net, audio_net, num_classes=13)

    # NOTE: args.gpu_ids is parsed above but not used; four GPUs are assumed
    gpu_ids = list(range(4))
    fusion_net_cuda = torch.nn.DataParallel(fusion_net,
                                            device_ids=gpu_ids).cuda()

    loss_func_CE = torch.nn.CrossEntropyLoss()
    loss_func_BCE = torch.nn.BCELoss(reduction='mean')  # 'reduce' is deprecated
    loss_func_MSE = torch.nn.MSELoss()

    optimizer = optim.Adam(params=fusion_net_cuda.parameters(),
                           lr=args.learning_rate,
                           betas=(0.9, 0.999),
                           weight_decay=0.0001)

    max_fscore = 0.
    count = 0
    for e in range(args.epoch):

        fusion_net_cuda.train()
        begin_time = datetime.datetime.now()

        scene_loss = 0.0
        event_loss = 0.0
        batch_num = int(len(train_dataloader.dataset) / args.batch_size)

        for i, data in enumerate(train_dataloader, 0):
            # print('batch:%d/%d' % (i,batch_num))
            img, aud, scene_label, event_label, _ = data
            img, aud, scene_label, event_label = img.type(
                torch.FloatTensor).cuda(), aud.type(
                    torch.FloatTensor).cuda(), scene_label.type(
                        torch.LongTensor).cuda(), event_label.type(
                            torch.FloatTensor).cuda()

            optimizer.zero_grad()

            scene_output, KD_output = fusion_net_cuda(img, aud)
            CE_loss = loss_func_CE(scene_output, scene_label)
            BCE_loss = loss_func_BCE(KD_output, event_label) * args.kd_weight

            #CE_loss.backward(retain_graph=True)
            #MSE_loss.backward()
            losses = CE_loss + BCE_loss
            losses.backward()
            optimizer.step()

            # accumulate as floats so the autograd graph is freed each step
            scene_loss += CE_loss.item()
            event_loss += BCE_loss.item()

        end_time = datetime.datetime.now()
        delta_time = (end_time - begin_time)
        delta_seconds = delta_time.total_seconds()

        (val_acc, val_precision, val_recall,
         val_fscore) = net_test(fusion_net_cuda, val_dataloader)
        print(
            'epoch:%d scene loss:%.4f event loss:%.4f val acc:%.4f val_precision:%.4f val_recall:%.4f val_fscore:%.4f'
            % (e, scene_loss, event_loss, val_acc, val_precision,
               val_recall, val_fscore))
        if val_fscore > max_fscore:
            count = 0
            max_fscore = val_fscore
            (test_acc, test_precision, test_recall,
             test_fscore) = net_test(fusion_net_cuda, test_dataloader)
            test_acc_list = [test_acc]
            test_precision_list = [test_precision]
            test_recall_list = [test_recall]
            test_fscore_list = [test_fscore]
            print('mark...')
            #print('test acc:%.4f precision:%.4f recall:%.4f fscore:%.4f' % (test_acc, test_precision, test_recall, test_fscore))
        else:
            count = count + 1
            (test_acc, test_precision, test_recall,
             test_fscore) = net_test(fusion_net_cuda, test_dataloader)

            test_acc_list.append(test_acc)
            test_precision_list.append(test_precision)
            test_recall_list.append(test_recall)
            test_fscore_list.append(test_fscore)

        if count == 5:
            test_acc_mean = np.mean(test_acc_list)
            test_acc_std = np.std(test_acc_list)

            test_precision_mean = np.mean(test_precision_list)
            test_precision_std = np.std(test_precision_list)

            test_recall_mean = np.mean(test_recall_list)
            test_recall_std = np.std(test_recall_list)

            test_fscore_mean = np.mean(test_fscore_list)
            test_fscore_std = np.std(test_fscore_list)

            print(
                'test acc:%.4f (%.4f) precision:%.4f (%.4f) recall:%.4f (%.4f) fscore:%.4f(%.4f)'
                % (test_acc_mean, test_acc_std, test_precision_mean,
                   test_precision_std, test_recall_mean, test_recall_std,
                   test_fscore_mean, test_fscore_std))
            count = 0

        if e in [30, 60, 90]:
            decrease_learning_rate(optimizer, 0.1)
            print('decreased learning rate by 0.1')
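
All four examples call a decrease_learning_rate helper that is not part of
the listing. A minimal sketch of what it presumably does, assuming it simply
scales the learning rate of every optimizer parameter group in place:

def decrease_learning_rate(optimizer, factor):
    # Multiply each parameter group's learning rate by `factor`
    # (e.g. factor=0.1 divides the learning rate by ten).
    for param_group in optimizer.param_groups:
        param_group['lr'] *= factor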
Example #2
def main():
    parser = argparse.ArgumentParser(description='AID_PRETRAIN')
    parser.add_argument('--dataset_dir',
                        type=str,
                        default='F:\\download\\CVS_Dataset_New\\',
                        help='the path of the dataset')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='training batch size')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=1e-5,
                        help='learning rate')
    parser.add_argument('--epoch',
                        type=int,
                        default=2000,
                        help='training epoch')
    parser.add_argument('--gpu_ids',
                        type=str,
                        default='[0,1,2,3]',
                        help="GPU ids to use, e.g. '[0,4]'")
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--image_net_weights',
                        type=str,
                        default='AID_visual_pretrain.pt',
                        help='image net weights')
    parser.add_argument('--audio_net_weights',
                        type=str,
                        default='audioset_audio_pretrain.pt',
                        help='audio net weights')

    parser.add_argument('--data_dir',
                        type=str,
                        default='/mnt/scratch/hudi/soundscape/data/',
                        help='the path of the data directory')
    parser.add_argument('--num_threads',
                        type=int,
                        default=8,
                        help='number of threads')
    parser.add_argument('--data_name', type=str, default='CVS_data_ind.pkl')
    parser.add_argument('--seed', type=int, default=10)
    parser.add_argument('--audionet_pretrain', type=int, default=1)
    parser.add_argument('--videonet_pretrain', type=int, default=1)
    parser.add_argument('--kd_weight', type=float, default=0.1)
    parser.add_argument('--reg_weight', type=float, default=0.001)

    # NOTE: combined with default=True, these store_true flags are always
    # on; use default=False if they are meant to be opt-in switches.
    parser.add_argument('--using_event_knowledge',
                        default=True,
                        action='store_true')
    parser.add_argument('--using_event_regularizer',
                        default=True,
                        action='store_true')

    args = parser.parse_args()

    (train_sample, train_label, val_sample, val_label, test_sample,
     test_label) = data_construction(args.data_dir)

    #f = open(args.data_name, 'wb')
    #data = {'train_sample':train_sample, 'train_label':train_label, 'test_sample':test_sample, 'test_label':test_label}
    #pickle.dump(data, f)
    #f.close()

    print('bayes model...')
    print('videonet_pretrain', args.videonet_pretrain)
    print('audionet_pretrain', args.audionet_pretrain)
    print('seed', args.seed)
    print('kd_weight', args.kd_weight)
    print('reg_weight', args.reg_weight)
    print('using_event_knowledge', args.using_event_knowledge)
    print('using_event_regularizer', args.using_event_regularizer)
    print('learning_rate', args.learning_rate)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    train_dataset = CVSDataset(args.data_dir,
                               train_sample,
                               train_label,
                               seed=args.seed)
    val_dataset = CVSDataset(args.data_dir,
                             val_sample,
                             val_label,
                             seed=args.seed)
    test_dataset = CVSDataset(args.data_dir,
                              test_sample,
                              test_label,
                              seed=args.seed)

    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_threads)
    val_dataloader = DataLoader(dataset=val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.num_threads)
    test_dataloader = DataLoader(dataset=test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_threads)

    image_net = IMG_NET(num_classes=30)
    if args.videonet_pretrain:
        state = torch.load(args.image_net_weights)
        image_net.load_state_dict(state)

    audio_net = AUD_NET()
    if args.audionet_pretrain:
        state = torch.load(args.audio_net_weights)['model']
        audio_net.load_state_dict(state)

    # assemble the audio-visual fusion network from the two branches
    fusion_net = FusionNet(image_net, audio_net, num_classes=13)

    # NOTE: args.gpu_ids is parsed above but not used; four GPUs are assumed
    gpu_ids = list(range(4))
    fusion_net_cuda = torch.nn.DataParallel(fusion_net,
                                            device_ids=gpu_ids).cuda()

    loss_func_CE = torch.nn.CrossEntropyLoss()
    loss_func_BCE = torch.nn.BCELoss(reduction='mean')  # 'reduce' is deprecated
    loss_func_COSINE = cosine_loss()
    softmax_ = nn.LogSoftmax(dim=1)
    loss_func_KL = torch.nn.KLDivLoss()

    optimizer = optim.Adam(params=fusion_net_cuda.parameters(),
                           lr=args.learning_rate,
                           betas=(0.9, 0.999),
                           weight_decay=0.0001)

    max_fscore = 0.

    scene_to_event = np.load('scene_to_event_prior_59.npy')
    #scene_to_event = np.expand_dims(scene_to_event, 0)

    scene_to_event = torch.from_numpy(scene_to_event).cuda()
    #scene_to_event = torch.unsqueeze(x, 0)
    #scene_to_event = scene_to_event.repeat(64,1)
    count = 0
    for e in range(args.epoch):

        fusion_net_cuda.train()
        begin_time = datetime.datetime.now()

        scene_loss = 0.
        event_loss = 0.
        regu_loss = 0.
        batch_num = int(len(train_dataloader.dataset) / args.batch_size)

        for i, data in enumerate(train_dataloader, 0):
            # print('batch:%d/%d' % (i,batch_num))
            img, aud, scene_label, event_label, event_corr = data
            sample_num = img.shape[0]
            img, aud, scene_label, event_label, event_corr = img.type(
                torch.FloatTensor).cuda(), aud.type(
                    torch.FloatTensor).cuda(), scene_label.type(
                        torch.LongTensor).cuda(), event_label.type(
                            torch.FloatTensor).cuda(), event_corr.type(
                                torch.FloatTensor).cuda()

            #scene_to_event = np.expand_dims(scene_to_event, 0)
            #scene_to_event_ = np.tile(scene_to_event, (sample_num,1,1))
            #scene_to_event_cuda = torch.from_numpy(scene_to_event_).cuda()

            optimizer.zero_grad()

            scene_output = fusion_net_cuda(img, aud)

            CE_loss = loss_func_CE(scene_output, scene_label)

            # accumulate as a float so the autograd graph is freed each step
            scene_loss += CE_loss.item()

            if args.using_event_knowledge:
                scene_prob = torch.nn.functional.softmax(scene_output, dim=1)
                event_output = scene_prob.mm(scene_to_event)

                kl_loss = loss_func_BCE(event_output,
                                        event_label) * args.kd_weight
                #cosine_loss_ = loss_func_COSINE(event_output, event_label) * args.kd_weight
                event_loss += kl_loss.item()

                if args.using_event_regularizer:
                    #print('tt')
                    #regularizer_loss = loss_func_KL(softmax_(event_output), softmax_(event_label))
                    regularizer_loss = loss_func_COSINE(
                        event_output,
                        event_corr) * args.kd_weight * args.reg_weight
                    losses = CE_loss + kl_loss + regularizer_loss
                    regu_loss += regularizer_loss.item()
                else:

                    losses = CE_loss + kl_loss
            else:
                losses = CE_loss

            losses.backward()
            optimizer.step()

        end_time = datetime.datetime.now()
        delta_time = (end_time - begin_time)
        delta_seconds = delta_time.total_seconds()

        (val_acc, val_precision, val_recall, val_fscore,
         _) = net_test(fusion_net_cuda, val_dataloader, scene_to_event, e)
        print(
            'epoch:%d scene loss:%.4f event loss:%.4f reg loss:%.4f val acc:%.4f val_precision:%.4f val_recall:%.4f val_fscore:%.4f'
            % (e, scene_loss, event_loss, regu_loss, val_acc,
               val_precision, val_recall, val_fscore))
        if val_fscore > max_fscore:
            count = 0
            max_fscore = val_fscore
            (test_acc, test_precision, test_recall, test_fscore,
             results) = net_test(fusion_net_cuda, test_dataloader,
                                 scene_to_event, e)
            #print(results)
            test_acc_list = [test_acc]
            test_precision_list = [test_precision]
            test_recall_list = [test_recall]
            test_fscore_list = [test_fscore]
            print('mark...')
            #print('test acc:%.4f precision:%.4f recall:%.4f fscore:%.4f' % (test_acc, test_precision, test_recall, test_fscore))

        else:
            count = count + 1
            (test_acc, test_precision, test_recall, test_fscore,
             results) = net_test(fusion_net_cuda, test_dataloader,
                                 scene_to_event, e)
            #print(results)

            test_acc_list.append(test_acc)
            test_precision_list.append(test_precision)
            test_recall_list.append(test_recall)
            test_fscore_list.append(test_fscore)

        if count == 5:
            test_acc_mean = np.mean(test_acc_list)
            test_acc_std = np.std(test_acc_list)

            test_precision_mean = np.mean(test_precision_list)
            test_precision_std = np.std(test_precision_list)

            test_recall_mean = np.mean(test_recall_list)
            test_recall_std = np.std(test_recall_list)

            test_fscore_mean = np.mean(test_fscore_list)
            test_fscore_std = np.std(test_fscore_list)

            print(
                'test acc:%.4f (%.4f) precision:%.4f (%.4f) recall:%.4f (%.4f) fscore:%.4f(%.4f)'
                % (test_acc_mean, test_acc_std, test_precision_mean,
                   test_precision_std, test_recall_mean, test_recall_std,
                   test_fscore_mean, test_fscore_std))
            count = 0
            test_acc_list = []
            test_precision_list = []
            test_recall_list = []
            test_fscore_list = []
            # Save model
            MODEL_PATH = 'checkpoint2'
            os.makedirs(MODEL_PATH, exist_ok=True)
            MODEL_FILE = os.path.join(
                MODEL_PATH,
                'bayes_checkpoint%d_%.3f.pt' % (e, test_fscore_mean))
            state = {
                'model': fusion_net_cuda.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            sys.stderr.write('Saving model to %s ...\n' % MODEL_FILE)
            torch.save(state, MODEL_FILE)

        if e in [30, 60, 90]:
            decrease_learning_rate(optimizer, 0.1)
            print('decreased learning rate by 0.1')
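
Example #2 also instantiates a cosine_loss module whose definition is not
included. A plausible sketch, assuming it penalizes the mean cosine distance
between corresponding rows of the prediction and target matrices (the exact
reduction is an assumption):

import torch.nn as nn
import torch.nn.functional as F

class cosine_loss(nn.Module):
    # Hypothetical reconstruction: 1 - cosine similarity per sample,
    # averaged over the batch; inputs are (batch, num_events) tensors.
    def forward(self, output, target):
        return (1.0 - F.cosine_similarity(output, target, dim=1)).mean()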
Example #3
def main():
    parser = argparse.ArgumentParser(description='AID_PRETRAIN')
    parser.add_argument('--dataset_dir',
                        type=str,
                        default='F:\\download\\CVS_Dataset_New\\',
                        help='the path of the dataset')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='training batch size')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=1e-4,
                        help='learning rate')
    parser.add_argument('--epoch',
                        type=int,
                        default=2000,
                        help='training epoch')
    parser.add_argument('--gpu_ids',
                        type=str,
                        default='[0,1,2,3]',
                        help="GPU ids to use, e.g. '[0,4]'")
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--av_net_ckpt',
                        type=str,
                        default='kd_checkpoint49_0.750.pt',
                        help='av net checkpoint')

    parser.add_argument('--data_dir',
                        type=str,
                        default='/mnt/scratch/hudi/soundscape/data/',
                        help='the path of the data directory')
    parser.add_argument('--num_threads',
                        type=int,
                        default=8,
                        help='number of threads')
    parser.add_argument('--data_name', type=str, default='CVS_data_ind.pkl')
    parser.add_argument('--seed', type=int, default=10)

    args = parser.parse_args()

    class_name = [
        'forest', 'harbour', 'farmland', 'grassland', 'airport', 'sports land',
        'bridge', 'beach', 'residential', 'orchard', 'train station', 'lake',
        'sparse shrub land'
    ]

    print('kd_model...')

    global features_blobs

    (train_sample, train_label, val_sample, val_label, test_sample,
     test_label) = data_construction(args.data_dir)

    test_dataset = CVSDataset(args.data_dir,
                              test_sample,
                              test_label,
                              seed=args.seed,
                              enhance=False,
                              event_label_name='event_label_bayes_59')

    test_dataloader = DataLoader(dataset=test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_threads)

    image_net = IMG_NET(num_classes=30)
    audio_net = AUD_NET()
    fusion_net = FusionNet_KD(image_net, audio_net, num_classes=13)

    fusion_net = torch.nn.DataParallel(fusion_net, device_ids=[0]).cuda()

    MODEL_PATH = 'checkpoint2'
    MODEL_FILE = os.path.join(MODEL_PATH, args.av_net_ckpt)
    state = torch.load(MODEL_FILE)
    fusion_net.load_state_dict(state['model'])

    fusion_net.eval()

    # get the softmax weight
    params = list(fusion_net.parameters())
    weight_softmax = np.squeeze(params[-4].data.cpu().numpy())

    fusion_net.module.image_net.layer4.register_forward_hook(hook_feature)

    with torch.no_grad():
        count = 0
        for i, data in enumerate(test_dataloader, 0):
            img, aud, label, _e, _r = data
            img, aud, label = img.type(torch.FloatTensor).cuda(), aud.type(
                torch.FloatTensor).cuda(), label.type(
                    torch.LongTensor).cuda()  # gpu
            logit, _ = fusion_net(img, aud)

            h_x = F.softmax(logit, dim=1).data.squeeze()
            for j in range(logit.shape[0]):
                h_x_current = h_x[j, :]
                probs, idx = h_x_current.sort(0, True)
                probs = probs.cpu().numpy()
                idx = idx.cpu().numpy()

                CAMs = returnCAM(features_blobs[0][j], weight_softmax,
                                 [label[j]])

                current_img = np.transpose(img[j].cpu().numpy(), [1, 2, 0])
                height, width, _ = current_img.shape
                heatmap = cv2.applyColorMap(
                    cv2.resize(CAMs[0], (width, height)), cv2.COLORMAP_JET)

                current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2RGB)
                result = heatmap * 0.3 + current_img * 0.5

                os.makedirs(os.path.join('cam3/kd/', args.av_net_ckpt),
                            exist_ok=True)
                #cv2.imwrite(os.path.join('cam/result/kd/', args.av_net_ckpt, class_name[label[j]]+'_%04d.jpg' % j), result)
                file_name = '%04d_' % count + class_name[
                    label[j]] + '_' + class_name[
                        idx[0]] + '_%.3f' % h_x_current[label[j]] + '.jpg'
                cv2.imwrite(
                    os.path.join('cam3/kd/', args.av_net_ckpt, file_name),
                    result)
                count += 1
            features_blobs = []
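
Example #3 depends on hook_feature and returnCAM, which are not shown but
appear to follow the standard class-activation-mapping (CAM) recipe. A sketch
under that assumption (the upsample size is a guess):

import numpy as np
import cv2

features_blobs = []

def hook_feature(module, input, output):
    # Forward hook registered on image_net.layer4: cache the conv
    # feature maps of each batch for the CAM computation below.
    features_blobs.append(output.data.cpu().numpy())

def returnCAM(feature_conv, weight_softmax, class_idx):
    # Weight the (nc, h, w) feature maps by the classifier weights of
    # each requested class, normalize to [0, 255], and upsample.
    size_upsample = (256, 256)
    nc, h, w = feature_conv.shape
    output_cam = []
    for idx in class_idx:
        cam = weight_softmax[idx].dot(feature_conv.reshape((nc, h * w)))
        cam = cam.reshape(h, w)
        cam = cam - np.min(cam)
        cam = np.uint8(255 * cam / (np.max(cam) + 1e-12))
        output_cam.append(cv2.resize(cam, size_upsample))
    return output_cam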
Example #4
def main():
    parser = argparse.ArgumentParser(description='AID_PRETRAIN')
    parser.add_argument('--dataset_dir',
                        type=str,
                        default='F:\\download\\CVS_Dataset_New\\',
                        help='the path of the dataset')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='training batch size')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=1e-3,
                        help='learning rate')
    parser.add_argument('--epoch',
                        type=int,
                        default=2000,
                        help='training epoch')
    parser.add_argument('--gpu_ids',
                        type=str,
                        default='[0,1,2,3]',
                        help="GPU ids to use, e.g. '[0,4]'")
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--image_net_weights',
                        type=str,
                        default='visual_model_pretrain.pt',
                        help='image net weights')
    parser.add_argument('--audio_net_weights',
                        type=str,
                        default='audio_pretrain_net.pt',
                        help='audio net weights')

    parser.add_argument('--data_dir',
                        type=str,
                        default='/mnt/scratch/hudi/soundscape/data/',
                        help='the path of the data directory')
    parser.add_argument('--num_threads',
                        type=int,
                        default=8,
                        help='number of threads')
    parser.add_argument('--data_name', type=str, default='CVS_data_ind.pkl')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--audionet_pretrain', type=int, default=1)
    parser.add_argument('--videonet_pretrain', type=int, default=1)

    args = parser.parse_args()

    (train_sample, train_label, val_sample, val_label, test_sample,
     test_label) = data_construction(args.data_dir)

    #f = open(args.data_name, 'wb')
    #data = {'train_sample':train_sample, 'train_label':train_label, 'test_sample':test_sample, 'test_label':test_label}
    #pickle.dump(data, f)
    #f.close()

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    train_dataset = CVSDataset(args.data_dir,
                               train_sample,
                               train_label,
                               seed=args.seed)
    val_dataset = CVSDataset(args.data_dir,
                             val_sample,
                             val_label,
                             seed=args.seed)
    test_dataset = CVSDataset(args.data_dir,
                              test_sample,
                              test_label,
                              seed=args.seed)

    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_threads)
    val_dataloader = DataLoader(dataset=val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.num_threads)
    test_dataloader = DataLoader(dataset=test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_threads)

    image_net = IMG_NET(num_classes=30)
    if args.videonet_pretrain:
        state = torch.load(args.image_net_weights)
        image_net.load_state_dict(state)

    audio_net = AUD_NET()
    if args.audionet_pretrain:
        state = torch.load(args.audio_net_weights)['model']
        audio_net.load_state_dict(state)

    # assemble the audio-visual fusion network from the two branches
    fusion_net = FUS_NET(image_net, audio_net, num_classes=13)

    # NOTE: args.gpu_ids is parsed above but not used; four GPUs are assumed
    gpu_ids = list(range(4))
    fusion_net_cuda = torch.nn.DataParallel(fusion_net,
                                            device_ids=gpu_ids).cuda()

    loss_func = torch.nn.CrossEntropyLoss()

    optimizer = optim.Adam(params=fusion_net_cuda.parameters(),
                           lr=args.learning_rate,
                           betas=(0.9, 0.999),
                           weight_decay=0.0001)

    max_fscore = 0.

    for e in range(args.epoch):

        fusion_net_cuda.train()
        begin_time = datetime.datetime.now()

        train_loss = 0.0
        batch_num = int(len(train_dataloader.dataset) / args.batch_size)

        for i, data in enumerate(train_dataloader, 0):
            # print('batch:%d/%d' % (i,batch_num))
            img, aud, label = data
            img, aud, label = img.type(torch.FloatTensor).cuda(), aud.type(
                torch.FloatTensor).cuda(), label.type(torch.LongTensor).cuda()

            optimizer.zero_grad()

            output = fusion_net_cuda(img, aud)
            loss = loss_func(output, label)
            loss.backward()
            optimizer.step()

            # accumulate as a float so the autograd graph is freed each step
            train_loss += loss.item()

        end_time = datetime.datetime.now()
        delta_time = (end_time - begin_time)
        delta_seconds = delta_time.total_seconds()

        (val_acc, val_precision, val_recall,
         val_fscore) = net_test(fusion_net_cuda, val_dataloader)
        print(
            'epoch:%d loss:%.4f time:%.4f val acc:%.4f val_precision:%.4f val_recall:%.4f val_fscore:%.4f'
            % (e, train_loss, delta_seconds, val_acc, val_precision,
               val_recall, val_fscore))
        if val_fscore > max_fscore:
            max_fscore = val_fscore
            (test_acc, test_precision, test_recall,
             test_fscore) = net_test(fusion_net_cuda, test_dataloader)
            print('test acc:%.4f precision:%.4f recall:%.4f fscore:%.4f' %
                  (test_acc, test_precision, test_recall, test_fscore))

        if e in [30, 60, 90]:
            decrease_learning_rate(optimizer, 0.1)
            print('decreased learning rate by 0.1')
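
Finally, every example evaluates with a net_test helper that the listing does
not define. A hedged sketch of the four-value variant used in Examples #1 and
#4 (Example #2 uses a five-value variant that also takes scene_to_event and
the epoch), assuming macro-averaged metrics over argmax scene predictions;
the sklearn usage and batch layout here are assumptions:

import numpy as np
import torch
from sklearn.metrics import precision_recall_fscore_support

def net_test(net, dataloader):
    # Hypothetical evaluation loop: collect argmax scene predictions
    # and compute accuracy plus macro precision/recall/F-score.
    net.eval()
    preds, trues = [], []
    with torch.no_grad():
        for data in dataloader:
            img, aud, label = data[0], data[1], data[2]
            img = img.type(torch.FloatTensor).cuda()
            aud = aud.type(torch.FloatTensor).cuda()
            output = net(img, aud)
            if isinstance(output, tuple):  # FusionNet_KD returns two heads
                output = output[0]
            preds.append(output.argmax(dim=1).cpu().numpy())
            trues.append(label.numpy())
    y_pred, y_true = np.concatenate(preds), np.concatenate(trues)
    acc = float((y_pred == y_true).mean())
    precision, recall, fscore, _ = precision_recall_fscore_support(
        y_true, y_pred, average='macro', zero_division=0)
    return acc, precision, recall, fscore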