class Action_Recognition:
    def __init__(self, model_file, sample_duration, model_type, cuda_id=0):

        self.opt = parse_opts()

        self.opt.model = model_type

        self.opt.root_path = './C3D_ResNet/data'

        self.opt.resume_path = os.path.join(self.opt.root_path, model_file)
        self.opt.pretrain_path = os.path.join(self.opt.root_path, 'models/resnet-18-kinetics.pth')

        self.opt.cuda_id = cuda_id
        self.opt.dataset = 'ucf101'
        self.opt.n_classes = 400  # Kinetics-400 head of the pretrained backbone
        self.opt.n_finetune_classes = 3  # clean / normal / scratch (see runCAM below)
        self.opt.ft_begin_index = 4
        self.opt.model_depth = 18
        self.opt.resnet_shortcut = 'A'
        self.opt.sample_duration = sample_duration
        self.opt.batch_size = 1
        self.opt.n_threads = 1
        self.opt.checkpoint = 5

        self.opt.arch = '{}-{}'.format(self.opt.model, self.opt.model_depth)
        self.opt.mean = get_mean(self.opt.norm_value, dataset=self.opt.mean_dataset)
        self.opt.std = get_std(self.opt.norm_value)
        # print(self.opt)

        print('Loading C3D action-recognition model...')

        self.model, parameters = generate_model(self.opt)
        # print(self.model)

        if self.opt.no_mean_norm and not self.opt.std_norm:
            norm_method = Normalize([0, 0, 0], [1, 1, 1])
        elif not self.opt.std_norm:
            norm_method = Normalize(self.opt.mean, [1, 1, 1])
        else:
            norm_method = Normalize(self.opt.mean, self.opt.std)

        if self.opt.resume_path:
            print('    loading checkpoint {}'.format(self.opt.resume_path))
            checkpoint = torch.load(self.opt.resume_path)
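            # torch.load restores tensors to the device they were saved from;
            # pass map_location=... to load a GPU checkpoint on a CPU-only machine.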
            # assert self.opt.arch == checkpoint['arch']

            self.opt.begin_epoch = checkpoint['epoch']
            self.model.load_state_dict(checkpoint['state_dict'])

        self.spatial_transform = Compose([
            ScaleQC(int(self.opt.sample_size / self.opt.scale_in_test)),
            CornerCrop(self.opt.sample_size, self.opt.crop_position_in_test),
            ToTensor(self.opt.norm_value), norm_method
        ])

        self.target_transform = ClassLabel()

        self.model.eval()

    def run(self, clip, heatmap=None):
        '''
        input: clip is a list of N sequences (batch size N), each of T continuous frames
        return: per-sequence top-1 label paired with the full class-probability vector
        '''

        # prepare dataset
        self.spatial_transform.randomize_parameters()
        clip_all = []
        for clip_batch in clip:
            clip_all.append(torch.stack([self.spatial_transform(img) for img in clip_batch], 0))
        clip = torch.stack(clip_all, 0).permute(0, 2, 1, 3, 4)

        if self.opt.cuda_id is None:
            inputs = Variable(clip)
        else:
            inputs = Variable(clip.cuda(self.opt.cuda_id))

        if self.opt.model == 'resnet_skeleton':
            heatmap_all = []
            for heatmap_batch in heatmap:
                heatmap_all.append(torch.stack([self.spatial_transform(img) for img in heatmap_batch], 0))
            heatmap = torch.stack(heatmap_all, 0).permute(0, 2, 1, 3, 4)

            if self.opt.cuda_id is None:
                heatmap = Variable(heatmap)
            else:
                heatmap = Variable(heatmap.cuda(self.opt.cuda_id))

        # run model
        if self.opt.model == 'resnet_skeleton':
            outputs = self.model(inputs, heatmap)
        else:
            outputs = self.model(inputs)
        outputs = F.softmax(outputs, dim=1)

        sorted_scores, locs = torch.topk(outputs, k=1, dim=1)
        labels = locs.detach().cpu().numpy().tolist()
        probs = outputs.detach().cpu().numpy().tolist()
        result_labels = []
        for i, label in enumerate(labels):
            # pair each sequence's top-1 label with its full probability vector
            result_labels.append([label[0], probs[i]])

        print(result_labels)
        return result_labels

    def runCAM(self, out_image_name, clip, heatmap=None, is_heatmap=False):

        # hook the feature extractor
        finalconv_name = 'layer4'
        features_blobs = []
        def __hook_feature(module, input, output):
            features_blobs.append(output.data.cpu().numpy())

        # self.model is wrapped in nn.DataParallel, hence the .module indirection
        print(self.model.module._modules.get(finalconv_name))
        self.model.module._modules.get(finalconv_name).register_forward_hook(__hook_feature)

        # get the softmax weight
        print(self.model)
        params = list(self.model.parameters())
        weight_softmax = np.squeeze(params[-2].data.cpu().numpy())

        def returnCAM(feature_conv, weight_softmax, class_idx):
            # generate the class activation maps, upsampled to 256x256
            size_upsample = (256, 256)
            bz, nc, h, w = feature_conv.shape
            output_cam = []
            for idx in class_idx:
                cam = weight_softmax[idx].dot(feature_conv.reshape((nc, h*w)))
                cam = cam.reshape(h, w)
                cam = cam - np.min(cam)
                cam_img = cam / np.max(cam)
                cam_img = np.uint8(255 * cam_img)
                output_cam.append(cv2.resize(cam_img, size_upsample))
            return output_cam
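        # returnCAM implements the standard class activation map: for class c,
        # CAM_c(h, w) = sum_k w[c, k] * F[k, h, w], then min-max scaled to [0, 255].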

        # class labels of this fine-tuned model
        classes = {0: 'clean', 1: 'normal', 2: 'scratch'}
        ori_cam_img_list = []
        for t in range(30):  # assumes sample_duration == 30
            ori_cam_img_list.append(cv2.cvtColor(np.asarray(clip[0][t]), cv2.COLOR_BGR2RGB))

        # prepare dataset
        self.spatial_transform.randomize_parameters()
        clip_all = []
        for clip_batch in clip:
            clip_all.append(torch.stack([self.spatial_transform(img) for img in clip_batch], 0))
        clip = torch.stack(clip_all, 0).permute(0, 2, 1, 3, 4)

        if self.opt.cuda_id is None:
            inputs = Variable(clip)
        else:
            inputs = Variable(clip.cuda(self.opt.cuda_id))

        if self.opt.model == 'resnet_skeleton':
            heatmap_all = []
            for heatmap_batch in heatmap:
                heatmap_all.append(torch.stack([self.spatial_transform(img) for img in heatmap_batch], 0))
            heatmap = torch.stack(heatmap_all, 0).permute(0, 2, 1, 3, 4)

            if self.opt.cuda_id is None:
                heatmap = Variable(heatmap)
            else:
                heatmap = Variable(heatmap.cuda(self.opt.cuda_id))

        # run model
        if self.opt.model == 'resnet_skeleton':
            outputs = self.model(inputs, heatmap)
        else:
            outputs = self.model(inputs)

        h_x = F.softmax(outputs, dim=1).data.squeeze()
        probs, idx = h_x.sort(0, True)
        probs = probs.cpu().numpy()
        idx = idx.cpu().numpy()

        # output the prediction
        for i in range(0, 3):
            print('{:.3f} -> {}'.format(probs[i], classes[idx[i]]))

        # generate class activation mapping for the top1 prediction
        print(features_blobs[0].shape)
        # collapse the temporal axis of the 3D-CNN features: (N, C, T, H, W) -> (N, C, H, W)
        features_blobs[0] = np.mean(features_blobs[0], axis=2)
        print(features_blobs[0].shape)
        # weight_softmax = weight_softmax[:, 512:]
        print(weight_softmax.shape)
        print(idx[0])
        CAMs = returnCAM(features_blobs[0], weight_softmax, [idx[0]])

        # render the CAM and output
        height, width, _ = ori_cam_img_list[0].shape
        heatmap = cv2.applyColorMap(cv2.resize(CAMs[0], (width, height)), cv2.COLORMAP_JET)

        cv2.namedWindow('result')
        flag = True
        while flag:
            for t in range(30):
                result = heatmap * 0.3 + ori_cam_img_list[t] * 0.5
                result = result.astype('uint8')

                cv2.imshow('result', result)
                c = cv2.waitKey(0) % 256
                if c == 13:  # Enter key ends the playback loop
                    flag = False
                    break

        cv2.destroyAllWindows()
        if is_heatmap:
            cv2.imwrite('CAM-skeleton.jpg', result)
            print('output CAM-skeleton.jpg for the top1 prediction: %s'%classes[idx[0]])
        else:
            cv2.imwrite('CAM.jpg', result)
            print('output CAM.jpg for the top1 prediction: %s'%classes[idx[0]])
        return
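
# A minimal usage sketch (not part of the original listing): the checkpoint
# name, frame paths, and sample_duration below are hypothetical.
if __name__ == '__main__':
    from PIL import Image
    recognizer = Action_Recognition('save_30.pth', sample_duration=30,
                                    model_type='resnet', cuda_id=0)
    frames = [Image.open('frames/image_%05d.jpg' % t) for t in range(1, 31)]
    print(recognizer.run([frames]))  # one batch of 30 RGB frames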
Example #2
class Run:
    def __init__(self):

        self.model_methods = [['resnext', 'gradcam', 'camshow']]

        self.classes = [
            "brush_hair", "cartwheel", "catch", "chew", "clap", "climb",
            "climb_stairs", "dive", "draw_sword", "dribble", "drink", "eat",
            "fall_floor", "fencing", "flic_flac", "golf", "handstand", "hit",
            "hug", "jump", "kick", "kick_ball", "kiss", "laugh", "pick",
            "pour", "pullup", "punch", "push", "pushup", "ride_bike",
            "ride_horse", "run", "shake_hands", "shoot_ball", "shoot_bow",
            "shoot_gun", "sit", "situp", "smile", "smoke", "somersault",
            "stand", "swing_baseball", "sword", "sword_exercise", "talk",
            "throw", "turn", "walk", "wave"
        ]

        scales = [1.0]

        self.spatial_transform = Compose([
            MultiScaleCornerCrop(scales, 112),
            ToTensor(1.0),
            Normalize(get_mean(1.0, dataset='activitynet'), get_std(1.0))
        ])

        self.spatial_transform2 = Compose([MultiScaleCornerCrop(scales, 112)])

        self.spatial_transform3 = Compose([
            MultiScaleCornerCrop(scales, 112),
            ToTensor(1),
            Normalize([0, 0, 0], [1, 1, 1])
        ])

        self.model = utils.load_model(self.model_methods[0][0])
        self.model.cuda()
        #self.video=[]
        #self.flows=[]
        self.bb_frames = []
        #self.explainer= get_explainer
        method_name = 'gradcam'
        self.explainer = get_explainer(self.model, method_name, "conv1")
        self.explainer2 = get_explainer(self.model, method_name, "layer1")
        self.explainer3 = get_explainer(self.model, method_name, "layer2")
        self.explainer4 = get_explainer(self.model, method_name, "layer3")
        self.explainer5 = get_explainer(self.model, method_name, "layer4")
        self.explainer6 = get_explainer(self.model, method_name, "avgpool")
        path = "images/frames4"
        #print path
        self.path = path + "/"
        #dirc = os.listdir(path)
        #self.files = [ fname for fname in dirc if fname.startswith('img')]
        #self.files2 = [ fname for fname in dirc if fname.startswith('flow_x')]
        self.seq = []
        self.kls = []
        self.scr = []
        self.totalhit = 0
        self.totalhit2 = 0
        self.totalhit3 = 0
        self.totalhit4 = 0
        self.totalhit5 = 0
        self.totalhit6 = 0
        self.totalhit7 = 0
        self.totalframes = 0

    def myRange(self, start, end, step):
        i = start
        while i < end:
            yield i
            i += step
        yield end
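    # Note: unlike range(), the generator above also yields `end` as its final
    # value, e.g. list(self.myRange(0, 10, 4)) == [0, 4, 8, 10].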

    def bounding_box(self, points):
        x_coordinates, y_coordinates = zip(*points)
        return [(min(x_coordinates), min(y_coordinates)),
                (max(x_coordinates), max(y_coordinates))]
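    # e.g. self.bounding_box([(1, 5), (4, 2), (3, 3)]) -> [(1, 2), (4, 5)]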

    def saliency(self, path):

        #path = path.replace("base2","base")
        path_gt = path.replace("base/", "base2/")

        path = "./" + path
        path_gt = "./" + path_gt
        print(path_gt)
        if os.path.isdir(path_gt):

            for model_name, method_name, _ in self.model_methods:

                dirc = os.listdir(path)
                dirc_gt = os.listdir(path_gt)

                files = [fname for fname in dirc if fname.endswith('png')]
                files_gt = [
                    fname for fname in dirc_gt if fname.endswith('mat')
                ]

            def bounding_box(points):
                x_coordinates, y_coordinates = zip(*points)
                return [(min(x_coordinates), min(y_coordinates)),
                        (max(x_coordinates), max(y_coordinates))]

            for index in self.myRange(0, len(files), 16):
                #print "frame ke", index, "from", len(files)
                if index == len(files):
                    continue
                video = []
                flows = []
                boxes = []

                for filename in sorted(files)[index:index + 16]:
                    video.append(Image.open(path + filename))

                diff = 0
                matfile = scipy.io.loadmat(path_gt + files_gt[0])
                coor = np.array(matfile["pos_img"]).transpose(
                    2, 1, 0).tolist()[index:index + 16]
                scale = matfile["scale"][0]

                if len(coor) == 0:
                    continue

                print(len(video), len(coor))

                if len(coor) != 16 or len(video) != 16:
                    diff = len(video)
                    # pad short windows by tiling, then truncate back to 16 frames
                    video = video * (16 // len(video)) * 2
                    video = video[0:16]
                    coor = coor * (16 // len(coor)) * 2
                    coor = coor[0:16]

                print(len(video), len(coor))
                if len(coor) == 0:
                    continue
                for e in range(len(coor)):
                    # scale ground-truth joint coordinates from the original
                    # 320x240 frames to the 112x112 network input
                    box = bounding_box([(abs(dots[0]) * 112.0 / 320.0,
                                         abs(dots[1]) * 112.0 / 240.0)
                                        for dots in coor[e]])
                    boxes.append(box)

                self.spatial_transform.randomize_parameters()
                self.spatial_transform2.randomize_parameters()
                self.spatial_transform3.randomize_parameters()

                clip = [self.spatial_transform3(img) for img in video]
                inp = torch.stack(clip, 0).permute(1, 0, 2, 3)

                all_saliency_maps = []
                for model_name, method_name, _ in self.model_methods:

                    if method_name == 'googlenet':  # swap channel due to caffe weights
                        inp_copy = inp.clone()
                        inp[0] = inp_copy[2]
                        inp[2] = inp_copy[0]
                    inp = utils.cuda_var(inp.unsqueeze(0), requires_grad=True)

                    saliency, s, kls, scr, c = self.explainer.explain(inp)
                    saliency2, s2, kls, scr, c2 = self.explainer2.explain(inp)
                    saliency3, s3, kls, scr, c3 = self.explainer3.explain(inp)
                    saliency4, s4, kls, scr, c4 = self.explainer4.explain(inp)
                    saliency5, s5, kls, scr, c5 = self.explainer5.explain(inp)
                    saliency6, pool, kls, scr, c6 = self.explainer6.explain(
                        inp)

                    torch.cuda.empty_cache()

                    saliency = (saliency6 + saliency5 + saliency4 + saliency3 +
                                saliency2 + saliency)

                    if self.classes[kls] == path.split("/")[5]:
                        label = 1
                    else:
                        label = 0

                    saliency = torch.clamp(saliency, min=0)

                    temp = saliency.shape[2]

                    if temp > 1:
                        all_saliency_maps.append(
                            saliency.squeeze().cpu().data.numpy())
                    else:
                        all_saliency_maps.append(
                            saliency.squeeze().unsqueeze(0).cpu().numpy())

                    del pool, inp, saliency, saliency6
                    torch.cuda.empty_cache()

                plt.figure(figsize=(50, 5))

                for k in range(len(video[0:(16 - diff)])):
                    hit = 0
                    hit2 = 0
                    hit3 = 0
                    hit4 = 0
                    hit5 = 0
                    hit6 = 0
                    hit7 = 0
                    plt.subplot(2, 16, k + 1)
                    img = self.spatial_transform2(video[k])

                    if len(boxes) > 0:
                        box = boxes[k]

                        x = box[0][0]
                        y = box[0][1]
                        w = box[1][0] - x
                        h = box[1][1] - y

                        if (x + w) > 112:
                            w = (112 - x)
                        if (y + h) > 112:
                            h = (112 - y)

                        ax = viz.plot_bbox([x, y, w, h], img)

                    plt.axis('off')
                    ax = plt.gca()

                    ax.imshow(img)
                    sal = all_saliency_maps[0][k]
                    sal = (sal - np.mean(sal)) / np.std(sal)
                    ret, thresh = cv2.threshold(
                        sal,
                        np.mean(sal) + ((np.amax(sal) - np.mean(sal)) * 0.5),
                        1, cv2.THRESH_BINARY)

                    contours, hierarchy = cv2.findContours(
                        thresh.astype(np.uint8), 1, 2)

                    areas = [cv2.contourArea(c) for c in contours]

                    if len(contours) > 0:

                        glob = np.array(
                            [cv2.boundingRect(cnt) for cnt in contours])
                        #print glob.shape
                        x3 = np.amin(glob[:, 0])
                        y3 = np.amin(glob[:, 1])
                        x13 = np.amax(glob[:, 0] + glob[:, 2])
                        y13 = np.amax(glob[:, 1] + glob[:, 3])

                        rect3 = patches.Rectangle((x3, y3),
                                                  x13 - x3,
                                                  y13 - y3,
                                                  linewidth=2,
                                                  edgecolor='y',
                                                  facecolor='none')
                        ax.add_patch(rect3)

                        for cnt in contours:
                            x2, y2, w2, h2 = cv2.boundingRect(cnt)

                            rect2 = patches.Rectangle((x2, y2),
                                                      w2,
                                                      h2,
                                                      linewidth=2,
                                                      edgecolor='r',
                                                      facecolor='none')
                            ax.add_patch(rect2)

                        # ground-truth box coords (x, y, w, h) exist only when boxes is non-empty
                        overlap = 0.0
                        if len(boxes) > 0:
                            overlap = nms.get_iou([x, x + w, y, y + h],
                                                  [x3, x13, y3, y13])

                        if label == 1:
                            if overlap >= 0.6:
                                hit = 1
                            if overlap >= 0.5:
                                hit2 = 1
                            if overlap >= 0.4:
                                hit3 = 1
                            if overlap >= 0.3:
                                hit4 = 1
                            if overlap >= 0.2:
                                hit5 = 1
                            if overlap >= 0.1:
                                hit6 = 1
                            if overlap > 0.0:
                                hit7 = 1

                    self.totalhit += hit
                    self.totalhit2 += hit2
                    self.totalhit3 += hit3
                    self.totalhit4 += hit4
                    self.totalhit5 += hit5
                    self.totalhit6 += hit6
                    self.totalhit7 += hit7
                    self.totalframes += 1
                    print "================="
                    print "accuracy0.6=", float(
                        self.totalhit) / self.totalframes
                    print "accuracy0.5=", float(
                        self.totalhit2) / self.totalframes
                    print "accuracy0.4=", float(
                        self.totalhit3) / self.totalframes
                    print "accuracy0.3=", float(
                        self.totalhit4) / self.totalframes
                    print "accuracy0.2=", float(
                        self.totalhit5) / self.totalframes
                    print "accuracy0.1=", float(
                        self.totalhit6) / self.totalframes
                    print "accuracy0.0=", float(
                        self.totalhit7) / self.totalframes

                    for saliency in all_saliency_maps:
                        show_style = 'camshow'

                        plt.subplot(2, 16, k + 17)
                        if show_style == 'camshow':

                            viz.plot_cam(np.abs(saliency[k]).squeeze(),
                                         img,
                                         'jet',
                                         alpha=0.5)

                            plt.axis('off')
                            plt.title(float(np.average(saliency[k])))

                            self.seq.append(
                                np.array(np.expand_dims(saliency[k], axis=2)) *
                                np.array(img))

                        else:
                            if model_name == 'googlenet' or method_name == 'pattern_net':
                                saliency = saliency.squeeze()[::-1].transpose(
                                    1, 2, 0)
                            else:
                                saliency = saliency.squeeze().transpose(
                                    1, 2, 0)
                            saliency -= saliency.min()
                            saliency /= (saliency.max() + 1e-20)
                            plt.imshow(saliency, cmap='gray')

                        if method_name == 'excitation_backprop':
                            plt.title('Exc_bp')
                        elif method_name == 'contrastive_excitation_backprop':
                            plt.title('CExc_bp')
                        else:
                            plt.title('%s' % (method_name))

                plt.tight_layout()
                print path.split("/")

                plt.savefig('./embrace_%i_%s.png' %
                            (index, path.split("/")[-2]))

            torch.cuda.empty_cache()

            print path.split("/")

            return self.seq, self.kls, self.scr

    def score(self):
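        # Note: this method reads self.opt (mean, std, sample_size, pretrain_path,
        # annotation_path, ...), which Run.__init__ above never sets; it assumes
        # the opts object was attached to the instance elsewhere.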

        normalize = get_normalize_method(self.opt.mean, self.opt.std, self.opt.no_mean_norm,
                                         self.opt.no_std_norm)
        spatial_transform = [
            Resize(self.opt.sample_size),
            CenterCrop(self.opt.sample_size),
            ToTensor()
        ]

        spatial_transform.extend([ScaleValue(self.opt.value_scale), normalize])
        spatial_transform = Compose(spatial_transform)

        temporal_transform = []
        if self.opt.sample_t_stride > 1:
            temporal_transform.append(TemporalSubsampling(self.opt.sample_t_stride))
        temporal_transform.append(
            TemporalEvenCrop(self.opt.sample_duration, self.opt.n_val_samples))
        temporal_transform = TemporalCompose(temporal_transform)


        frame_count = get_n_frames(self.opt.video_jpgs_dir_path)

        frame_indices = list(range(0, frame_count))

        frame_indices = temporal_transform(frame_indices)

        spatial_transform.randomize_parameters()

        image_name_formatter = lambda x: f'image_{x:05d}.jpg'
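        # e.g. image_name_formatter(3) -> 'image_00003.jpg'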

        loader = VideoLoader(image_name_formatter)

        print('frame_indices', frame_indices)

        video_outputs = []
        model = generate_model(self.opt)


        model = load_pretrained_model(model, self.opt.pretrain_path, self.opt.model,
                                      self.opt.n_finetune_classes)

        for i, frame_indice in enumerate(frame_indices):
            print("%d indice: %s" % (i, str(frame_indice)))

            clip = loader(self.opt.video_jpgs_dir_path, frame_indice)

            clip = [spatial_transform(img) for img in clip]
            clip = torch.stack(clip, 0).permute(1, 0, 2, 3)

            with torch.no_grad():
                print(clip.shape)
                output = model(torch.unsqueeze(clip, 0))
                output = F.softmax(output, dim=1).cpu()
                video_outputs.append(output[0])

            del clip

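        # average clip-level softmax scores into a single video-level score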
        video_outputs = torch.stack(video_outputs)
        average_scores = torch.mean(video_outputs, dim=0)

        #inference_loader, inference_class_names = main.get_inference_utils(self.opt)
        with self.opt.annotation_path.open('r') as f:
            data = json.load(f)

        class_to_idx = get_class_labels(data)
        idx_to_class = {label: name for name, label in class_to_idx.items()}
        print(idx_to_class)

        inference_result = inference.get_video_results(
            average_scores, idx_to_class, self.opt.output_topk)

        print(inference_result)
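
# A hypothetical invocation sketch: saliency() reads the class name from
# path.split("/")[5], so the directory depth below is an assumption.
if __name__ == '__main__':
    runner = Run()
    seq, kls, scr = runner.saliency('videos/hmdb51/frames/base/run/clip01/')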
Example #4
def main(configs):

    from utils.generate_dataset_jsons import generate_dataset_jsons
    generate_dataset_jsons(configs.dataset_folder)
    with open('dataset_rgb_train.json') as json_file:
        dataset_rgb_train_json = json.load(json_file)
    with open('dataset_rgb_valid.json') as json_file:
        dataset_rgb_valid_json = json.load(json_file)
    with open('dataset_mmaps_train.json') as json_file:
        dataset_mmaps_train_json = json.load(json_file)
    with open('dataset_mmaps_valid.json') as json_file:
        dataset_mmaps_valid_json = json.load(json_file)
    with open('dataset_flow_train.json') as json_file:
        dataset_flow_train_json = json.load(json_file)
    with open('dataset_flow_valid.json') as json_file:
        dataset_flow_valid_json = json.load(json_file)

    torch.backends.cudnn.benchmark = True

    if os.path.exists(configs.output_folder):
        print('Warning: output folder {} already exists!'.format(
            configs.output_folder))
    os.makedirs(configs.output_folder, exist_ok=True)

    normalize = Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])
    normalize_5 = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    randomHorizontalFlip = RandomHorizontalFlip()
    randomMultiScaleCornerCrop = MultiScaleCornerCrop(
        [1, 0.875, 0.75, 0.65625], 224)

    train_transforms = Compose([
        randomHorizontalFlip, randomMultiScaleCornerCrop,
        ToTensor(), normalize
    ])

    train_transforms.randomize_parameters()

    valid_transforms = Compose([CenterCrop(224), ToTensor(), normalize])

    train_transforms_mmaps = Compose([
        randomHorizontalFlip, randomMultiScaleCornerCrop,
        Scale(7),
        ToTensor(),
        Binarize()
    ])

    valid_transforms_mmaps = Compose(
        [CenterCrop(224), Scale(7),
         ToTensor(), Binarize()])

    train_transforms_5 = Compose([
        randomHorizontalFlip, randomMultiScaleCornerCrop,
        ToTensor(), normalize_5
    ])

    valid_transforms_5 = Compose([CenterCrop(224), ToTensor(), normalize_5])

    print('Loading dataset')

    if configs.dataset == 'DatasetRGB':
        dataset_rgb_train = DatasetRGB(configs.dataset_folder,
                                       dataset_rgb_train_json,
                                       spatial_transform=train_transforms,
                                       seqLen=configs.dataset_rgb_n_frames,
                                       minLen=5)
        dataset_rgb_valid = DatasetRGB(configs.dataset_folder,
                                       dataset_rgb_valid_json,
                                       spatial_transform=valid_transforms,
                                       seqLen=configs.dataset_rgb_n_frames,
                                       minLen=5,
                                       device='cuda')

        dataset_train = dataset_rgb_train
        dataset_valid = dataset_rgb_valid

        forward_fn = forward_rgb
    elif configs.dataset == 'DatasetRGBMMAPS':
        dataset_rgb_train = DatasetRGB(configs.dataset_folder,
                                       dataset_rgb_train_json,
                                       spatial_transform=train_transforms,
                                       seqLen=configs.dataset_rgb_n_frames,
                                       minLen=5)
        dataset_rgb_valid = DatasetRGB(configs.dataset_folder,
                                       dataset_rgb_valid_json,
                                       spatial_transform=valid_transforms,
                                       seqLen=configs.dataset_rgb_n_frames,
                                       minLen=5,
                                       device='cuda')
        dataset_mmaps_train = DatasetMMAPS(
            configs.dataset_folder,
            dataset_mmaps_train_json,
            spatial_transform=train_transforms_mmaps,
            seqLen=configs.dataset_rgb_n_frames,
            minLen=1,
            enable_randomize_transform=False)
        dataset_mmaps_valid = DatasetMMAPS(
            configs.dataset_folder,
            dataset_mmaps_valid_json,
            spatial_transform=valid_transforms_mmaps,
            seqLen=configs.dataset_rgb_n_frames,
            minLen=1,
            enable_randomize_transform=False,
            device='cuda')

        dataset_rgbmmaps_train = DatasetRGBMMAPS(dataset_rgb_train,
                                                 dataset_mmaps_train)
        dataset_rgbmmaps_valid = DatasetRGBMMAPS(dataset_rgb_valid,
                                                 dataset_mmaps_valid)

        dataset_train = dataset_rgbmmaps_train
        dataset_valid = dataset_rgbmmaps_valid

        forward_fn = forward_rgbmmaps
    elif configs.dataset == 'DatasetFlow':
        dataset_flow_train = DatasetFlow(
            configs.dataset_folder,
            dataset_flow_train_json,
            spatial_transform=train_transforms,
            stack_size=configs.dataset_flow_stack_size,
            sequence_mode='single_random')
        dataset_flow_valid = DatasetFlow(
            configs.dataset_folder,
            dataset_flow_valid_json,
            spatial_transform=valid_transforms,
            stack_size=configs.dataset_flow_stack_size,
            sequence_mode='single_midtime')

        dataset_train = dataset_flow_train
        dataset_valid = dataset_flow_valid

        forward_fn = forward_flow
    elif configs.dataset == 'DatasetFlowMultiple':
        dataset_flow_train = DatasetFlow(
            configs.dataset_folder,
            dataset_flow_train_json,
            spatial_transform=train_transforms_5,
            stack_size=configs.dataset_flow_stack_size,
            n_sequences=configs.dataset_flow_n_sequences,
            sequence_mode='multiple_jittered')
        dataset_flow_valid = DatasetFlow(
            configs.dataset_folder,
            dataset_flow_valid_json,
            spatial_transform=valid_transforms_5,
            stack_size=configs.dataset_flow_stack_size,
            n_sequences=configs.dataset_flow_n_sequences,
            sequence_mode='multiple')

        dataset_train = dataset_flow_train
        dataset_valid = dataset_flow_valid

        forward_fn = forward_flowmultiple
    elif configs.dataset == 'DatasetRGBFlow':
        dataset_rgb_train = DatasetRGB(configs.dataset_folder,
                                       dataset_rgb_train_json,
                                       spatial_transform=train_transforms,
                                       seqLen=configs.dataset_rgb_n_frames,
                                       minLen=5)
        dataset_rgb_valid = DatasetRGB(configs.dataset_folder,
                                       dataset_rgb_valid_json,
                                       spatial_transform=valid_transforms,
                                       seqLen=configs.dataset_rgb_n_frames,
                                       minLen=5)
        dataset_flow_train = DatasetFlow(
            configs.dataset_folder,
            dataset_flow_train_json,
            spatial_transform=train_transforms,
            stack_size=configs.dataset_flow_stack_size,
            sequence_mode='single_random',
            enable_randomize_transform=False)
        dataset_flow_valid = DatasetFlow(
            configs.dataset_folder,
            dataset_flow_valid_json,
            spatial_transform=valid_transforms,
            stack_size=configs.dataset_flow_stack_size,
            sequence_mode='single_midtime',
            enable_randomize_transform=False)

        dataset_rgbflow_train = DatasetRGBFlow(dataset_rgb_train,
                                               dataset_flow_train)
        dataset_rgbflow_valid = DatasetRGBFlow(dataset_rgb_valid,
                                               dataset_flow_valid)

        dataset_train = dataset_rgbflow_train
        dataset_valid = dataset_rgbflow_valid

        forward_fn = forward_rgbflow
    elif configs.dataset == 'DatasetRGBFlowMultiple':
        dataset_rgb_train = DatasetRGB(configs.dataset_folder,
                                       dataset_rgb_train_json,
                                       spatial_transform=train_transforms,
                                       seqLen=configs.dataset_rgb_n_frames,
                                       minLen=5)
        dataset_rgb_valid = DatasetRGB(configs.dataset_folder,
                                       dataset_rgb_valid_json,
                                       spatial_transform=valid_transforms,
                                       seqLen=configs.dataset_rgb_n_frames,
                                       minLen=5)
        dataset_flow_train = DatasetFlow(
            configs.dataset_folder,
            dataset_flow_train_json,
            spatial_transform=train_transforms_5,
            stack_size=configs.dataset_flow_stack_size,
            n_sequences=configs.dataset_rgb_n_frames,
            sequence_mode='multiple',
            enable_randomize_transform=False)
        dataset_flow_valid = DatasetFlow(
            configs.dataset_folder,
            dataset_flow_valid_json,
            spatial_transform=valid_transforms_5,
            stack_size=configs.dataset_flow_stack_size,
            n_sequences=configs.dataset_rgb_n_frames,
            sequence_mode='multiple',
            enable_randomize_transform=False)

        dataset_rgbflow_train = DatasetRGBFlow(dataset_rgb_train,
                                               dataset_flow_train)
        dataset_rgbflow_valid = DatasetRGBFlow(dataset_rgb_valid,
                                               dataset_flow_valid)

        dataset_train = dataset_rgbflow_train
        dataset_valid = dataset_rgbflow_valid

        forward_fn = forward_rgbflowmultiple
    else:
        raise ValueError('Unknown dataset type: {}'.format(configs.dataset))

    report_file = open(os.path.join(configs.output_folder, 'report.txt'), 'a')

    for config in configs:
        config_name = hashlib.md5(str(config).encode('utf-8')).hexdigest()
        print('Running', config_name)

        with open(
                os.path.join(configs.output_folder,
                             config_name + '.params.txt'), 'w') as f:
            f.write(pprint.pformat(config))

        train_loader = torch.utils.data.DataLoader(
            dataset_train, **config['TRAIN_DATA_LOADER'])
        valid_loader = torch.utils.data.DataLoader(
            dataset_valid, **config['VALID_DATA_LOADER'])

        model_class = config['MODEL']['model']
        model_params = {
            k: v
            for (k, v) in config['MODEL'].items() if k != 'model'
        }
        model = model_class(**model_params)
        if config['TRAINING'].get('_model_state_dict', None) is not None:
            model.load_weights(config['TRAINING']['_model_state_dict'])
        model.train(config['TRAINING']['train_mode'])
        model.cuda()

        loss_class = config['LOSS']['loss']
        loss_params = {
            k: v
            for (k, v) in config['LOSS'].items() if k != 'loss'
        }
        loss_fn = loss_class(**loss_params)

        model_weights = []
        for i in range(10):
            group_name = '_group_' + str(i)
            if group_name + '_params' in config['OPTIMIZER']:
                model_weights_group = {
                    'params':
                    model.get_training_parameters(
                        name=config['OPTIMIZER'][group_name + '_params'])
                }
                for k in config['OPTIMIZER']:
                    if k.startswith(group_name) and k != group_name + '_params':
                        # strip the '_group_<i>_' prefix (9 chars) to recover the kwarg name
                        model_weights_group[k[9:]] = config['OPTIMIZER'][k]
                model_weights.append(model_weights_group)
        if len(model_weights) == 0:
            model_weights = model.get_training_parameters()

        optimizer_class = config['OPTIMIZER']['optimizer']
        optimizer_params = {
            k: v
            for (k, v) in config['OPTIMIZER'].items()
            if k != 'optimizer' and not k.startswith('_')
        }
        optimizer = optimizer_class(model_weights, **optimizer_params)

        scheduler_class = config['SCHEDULER']['scheduler']
        scheduler_params = {
            k: v
            for (k, v) in config['SCHEDULER'].items()
            if k != 'scheduler' and not k.startswith('_')
        }
        scheduler_lr = scheduler_class(optimizer, **scheduler_params)

        model_state_dict_path = os.path.join(configs.output_folder,
                                             config_name + '.model')
        logfile = open(
            os.path.join(configs.output_folder, config_name + '.log.txt'), 'w')

        result = train_model(model=model,
                             train_loader=train_loader,
                             valid_loader=valid_loader,
                             forward_fn=forward_fn,
                             loss_fn=loss_fn,
                             optimizer=optimizer,
                             scheduler_lr=scheduler_lr,
                             model_state_dict_path=model_state_dict_path,
                             logfile=logfile,
                             **{
                                 k: v
                                 for (k, v) in config['TRAINING'].items()
                                 if not k.startswith('_')
                             })

        max_valid_accuracy_idx = np.argmax(result['accuracies_valid'])

        print(
            '{} | Train Loss {:04.2f} | Train Accuracy = {:05.2f}% | Valid Loss {:04.2f} | Valid Accuracy = {:05.2f}%'
            .format(config_name,
                    result['losses_train'][max_valid_accuracy_idx],
                    result['accuracies_train'][max_valid_accuracy_idx] * 100,
                    result['losses_valid'][max_valid_accuracy_idx],
                    result['accuracies_valid'][max_valid_accuracy_idx] * 100),
            file=report_file)