示例#1
0
    def forward(self, indices: torch.LongTensor):
        '''Extract features for the videos selected by `indices`.

        Arguments:
            indices {torch.LongTensor} -- indices to self.path_list
        '''
        device = indices.device

        model = r2plus1d_18(pretrained=True).to(device)
        model.eval()
        # save the pre-trained classifier for show_preds and replace it in the net with identity
        model_class = model.fc
        model.fc = torch.nn.Identity()

        for idx in indices:
            # when error occurs might fail silently when run from torch data parallel
            try:
                feats_dict = self.extract(device, model, model_class,
                                          self.path_list[idx])
                action_on_extraction(feats_dict, self.path_list[idx],
                                     self.output_path, self.on_extraction)
            except KeyboardInterrupt:
                # bare `raise` re-raises the active exception and preserves
                # its original traceback (raising a fresh one would not)
                raise
            except Exception as e:
                # prints only the last line of an error. Use `traceback.print_exc()` for the whole traceback
                # traceback.print_exc()
                print(e)
                print(
                    f'Extraction failed at: {self.path_list[idx]} with error (↑). Continuing extraction'
                )

            # update tqdm progress bar
            self.progress.update()
示例#2
0
def load_r2plus1d_18_net(parallel=False):
    """Build a pre-trained, wrapped R(2+1)D-18 network on the GPU.

    Arguments:
        parallel {bool} -- wrap the model in nn.DataParallel for multi-GPU use

    Returns:
        the wrapped model in eval mode, on CUDA
    """
    r2plus1d_18_model = r2plus1d_18(pretrained=True)
    r2plus1d_18_model = WrapR2plus1d_18(r2plus1d_18_model.eval()).cuda()
    if parallel:
        # fixed message: the previous text said "Inception", but this loader
        # builds an R(2+1)D-18 network
        print('Parallelizing R(2+1)D-18 module...')
        r2plus1d_18_model = nn.DataParallel(r2plus1d_18_model)
    return r2plus1d_18_model
    def __init__(self, args):
        """Multi-head classifier over a selectable 3D ResNet backbone.

        Arguments:
            args -- needs `resnet3d_model` ("r3d_18", "mc3_18", anything
                else selects r2plus1d_18) and `features` (iterable of
                head names)
        """
        super().__init__()
        # Stem maps 1-channel input to the 3 channels the backbones expect
        self.stem = nn.Conv3d(1, 3, (1, 3, 3), stride=1, padding=(0, 1, 1))

        # Choose the 3D ResNet variant; R(2+1)D-18 is the fallback
        if args.resnet3d_model == "r3d_18":
            self.resnet3d = r3d_18(pretrained=True)
        elif args.resnet3d_model == "mc3_18":
            self.resnet3d = mc3_18(pretrained=True)
        else:
            self.resnet3d = r2plus1d_18(pretrained=True)

        out_dim = self.resnet3d.fc.out_features
        self.resnet3d_out_features = out_dim

        self.features = args.features

        # Two FC layers (Kaiming-initialised, each followed by dropout)
        # between the backbone and the heads
        self.x1 = nn.Linear(out_dim, out_dim)
        nn.init.kaiming_normal_(self.x1.weight)
        self.dropout1 = nn.Dropout(p=0.2)
        self.x2 = nn.Linear(out_dim, out_dim // 2)
        nn.init.kaiming_normal_(self.x2.weight)
        self.dropout2 = nn.Dropout(p=0.2)

        # One single-output classifier head per requested feature
        for feature in self.features:
            setattr(self, f"{feature}_head",
                    ClassifierHead(out_dim // 2, 1))
示例#4
0
def resnet3d(num_classes, expansion=False, maxpool=False):
    """Build an R(2+1)D-18 classifier with a configurable head.

    Args:
        num_classes (int): size of the output layer.
        expansion (bool): if True, replace the final fc with a small
            Linear→BatchNorm→ReLU→Dropout→Linear head instead of a single
            linear layer.
        maxpool (bool): if True, swap the average pool for adaptive max
            pooling.

    Returns:
        torch.nn.modules.module.Module
    """
    model = r2plus1d_18(pretrained=False, progress=True)
    in_dim = model.fc.in_features

    if expansion:
        head = nn.Sequential(OrderedDict([
            ('dense', nn.Linear(in_features=in_dim, out_features=200)),
            ('norm', nn.BatchNorm1d(num_features=200)),
            ('relu', nn.ReLU()),
            ('dropout', nn.Dropout(p=0.25)),
            ('last', nn.Linear(in_features=200, out_features=num_classes)),
        ]))
    else:
        head = nn.Linear(in_dim, num_classes, bias=True)
    model.fc = head

    if maxpool:
        model.avgpool = nn.AdaptiveMaxPool3d(output_size=(1, 1, 1))

    return model
示例#5
0
 def __init__(
     self,
     vis: bool,
     disp_fig: bool,
     include_model: bool,
     show_gpu_utilization: bool,
     max_clips: int,
     imsz: int,
     loader_type: str,
     loader: Iterable,
     logger: logging.Logger,
 ):
     """Store profiler configuration and optionally build the model under test."""
     self.vis = vis
     self.disp_fig = disp_fig
     self.include_model = include_model
     self.show_gpu_utilization = show_gpu_utilization
     self.max_clips = max_clips
     self.imsz = imsz
     self.loader = loader
     self.loader_type = loader_type
     self.logger = logger
     # Prefer the GPU whenever CUDA is available
     self.device = "cuda" if torch.cuda.is_available() else "cpu"
     if self.include_model:
         # Random weights suffice here: the profiler measures throughput,
         # not accuracy
         self.model = r2plus1d_18(pretrained=False, progress=True).to(self.device)
     self.logger.info(
         f"{loader_type} profiler, include_model: {self.include_model}")
    def __init__(self, num_classes=2, sequence_length=8, pretrained=True):
        """Binary classifier built on a truncated FFSyncNet video branch."""
        super().__init__(num_classes=num_classes,
                         sequence_length=sequence_length)

        self.ff_sync_net = FFSyncNet(pretrained=pretrained)

        # Swap in an R(2+1)D-18 trunk truncated after layer2: the removed
        # stages (layer3/layer4) and the classifier become no-ops
        video_net = r2plus1d_18(pretrained=pretrained)
        video_net.layer3 = nn.Identity()
        video_net.layer4 = nn.Identity()
        video_net.fc = nn.Identity()
        self.ff_sync_net.r2plus1 = video_net

        # Project the 128-d truncated features into a 1024-d embedding
        self.ff_sync_net.video_mlp = nn.Sequential(
            nn.Linear(128, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 1024),
        )

        # Two-way head over a concatenated pair of 1024-d embeddings
        self.out = nn.Sequential(
            nn.Linear(1024 * 2, 50),
            nn.LeakyReLU(0.02),
            nn.Linear(50, 2),
        )
示例#7
0
    def __init__(self, way=5, shot=1, query=5):
        """R(2+1)D-18 encoder split into frozen and fine-tunable parts.

        Arguments:
            way {int} -- number of ways in the few-shot setting
            shot {int} -- number of shots per class
            query {int} -- number of query samples
        """
        super(R2Plus1D, self).__init__()
        self.way = way
        self.shot = shot
        self.query = query

        backbone = r2plus1d_18(pretrained=True)

        # Frozen feature extractor: stem through layer3
        self.encoder_freeze = nn.Sequential(
            backbone.stem,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
        )
        self.encoder_freeze.apply(freeze_all)

        # Fine-tuned tail: layer4 followed by global average pooling
        self.encoder_tune = nn.Sequential(
            backbone.layer4,
            nn.AdaptiveAvgPool3d(output_size=(1, 1, 1)),
        )

        # Learnable scalar (initialised to 5.0)
        self.scaler = nn.Parameter(torch.tensor(5.0))
示例#8
0
    def __init__(self, num_classes):
        """Frozen R(2+1)D-18 feature extractor with a two-layer head.

        NOTE(review): `num_classes` is ignored; the superclass always gets
        num_classes=2 — confirm this is intended.
        """
        super().__init__(num_classes=2)
        extractor = r2plus1d_18(pretrained=True)
        extractor.fc = nn.Identity()
        self.r2plus1 = extractor
        # The extractor is frozen: only the head below is trained
        self._set_requires_grad_for_module(self.r2plus1, requires_grad=False)

        # Head over a 512 + 1024 dimensional concatenated input
        self.out = nn.Sequential(nn.Linear(512 + 1024, 50), nn.ReLU(),
                                 nn.Linear(50, self.num_classes))
示例#9
0
    def __init__(self, num_classes):
        """R(2+1)D-18 truncated after layer3, with a two-layer head.

        NOTE(review): `num_classes` is ignored; the superclass always gets
        num_classes=2 — confirm this is intended.
        """
        super().__init__(num_classes=2)
        trunk = r2plus1d_18(pretrained=True)

        # Drop the last residual stage and the classifier (both become no-ops)
        trunk.layer4 = nn.Identity()
        trunk.fc = nn.Identity()
        self.r2plus1 = trunk

        # Head over the 256-d truncated features
        self.out = nn.Sequential(nn.Linear(256, 50), nn.ReLU(),
                                 nn.Linear(50, self.num_classes))
示例#10
0
def Backbone_R2Plus1D18_Custumed(in_C):
    """Split a pre-trained R(2+1)D-18 into reusable backbone stages.

    Args:
        in_C (int): number of input channels; must be 3 to match the
            pre-trained stem.

    Returns:
        tuple: the (stem, layer2, layer3, layer4) modules.
        NOTE(review): layer1 is not returned — confirm callers apply it
        (or deliberately skip it) between the stem and layer2.

    Raises:
        ValueError: if `in_C` is not 3.
    """
    # Raise instead of assert so the check survives `python -O`
    if in_C != 3:
        raise ValueError(f'expected 3 input channels, got {in_C}')
    model = r2plus1d_18(pretrained=True, progress=True)
    return model.stem, model.layer2, model.layer3, model.layer4
示例#11
0
 def __init__(self, cfg, args, tok: BertTokenizer):
     """Clip-embedding module: R(2+1)D-18 features projected to hidden size."""
     super().__init__(cfg)
     # Backbone whose classifier is swapped for a projection into the
     # transformer hidden size
     backbone = r2plus1d_18(pretrained=args.from_pretrained)
     backbone.fc = nn.Linear(in_features=512, out_features=cfg.hidden_size)
     self.clip_embeddings = backbone
     self.LayerNorm = BertLayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
     self.dropout = nn.Dropout(cfg.hidden_dropout_prob)
     self.tok = tok
     self.args = args
     # Optional position embeddings, controlled by a config flag
     if args.fixed_position_embeddings:
         self.position_embeddings = PositionEmbeddings(cfg.hidden_size, cfg.hidden_dropout_prob)
示例#12
0
def main(args):
    """Run a frozen R(2+1)D-18 over a video and print per-clip output shapes.

    Arguments:
        args -- expects `args.video` (path to the input video) and
            `args.timesteps` (frames per clip); confirm against the CLI
            parser defined elsewhere in this file.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if torch.cuda.is_available():
        print("Using CUDA, benchmarking implementations", file=sys.stderr)
        # let cuDNN auto-tune kernels for the (fixed) input shape
        torch.backends.cudnn.benchmark = True

    # r2d2 says "beep beep"
    resnet = r2plus1d_18(pretrained=True, progress=False)

    # drop the classifier so the network emits pooled features
    resnet.fc = nn.Identity()
    # resnet.avgpool = nn.Identity()

    # inference only: freeze every weight
    for params in resnet.parameters():
        params.requires_grad = False

    resnet = resnet.to(device)
    resnet = nn.DataParallel(resnet)

    resnet.eval()

    # Pre-trained Kinetics-400 statistics for normalization
    mean, std = [0.43216, 0.394666, 0.37645], [0.22803, 0.22145, 0.216989]

    # reshape to (1, C, 1, 1) so the stats broadcast over (T, C, H, W) frames
    mean = rearrange(torch.as_tensor(mean), "n -> () n () ()")
    std = rearrange(torch.as_tensor(std), "n -> () n () ()")

    video = vread(str(args.video))

    with torch.no_grad():
        for i, batch in enumerate(batched(video, args.timesteps)):
            # TODO:
            # - encapsulate video dataset
            # - abstract away transforms
            # - fix timesteps vs batching

            batch = rearrange(batch, "t h w c -> t c h w")
            batch = torch.tensor(batch)
            # scale pixel values (assumed 0-255 — TODO confirm vread dtype)
            # into [0, 1] floats before normalizing
            batch = batch.to(torch.float32) / 255

            batch = (batch - mean) / std

            # model expects NxCxTxHxW
            inputs = rearrange(batch, "t c h w -> () c t h w")
            inputs = inputs.to(device)

            outputs = resnet(inputs)
            outputs = rearrange(outputs, "() n -> n")
            outputs = outputs.data.cpu().numpy()

            print("seq={}, frames=range({}, {}), prediction={}".format(
                i, i * args.timesteps, (i + 1) * args.timesteps,
                outputs.shape))
示例#13
0
 def __init__(self, num_classes=5, sequence_length=8, contains_dropout=False):
     """Shallow video classifier: R(2+1)D-18 truncated after layer1.

     Layers 2-4 are replaced with identities, so the MLP head sees the
     64-d layer1 features.
     """
     super().__init__(
         num_classes=num_classes,
         sequence_length=sequence_length,
         contains_dropout=contains_dropout,
     )
     trunk = r2plus1d_18(pretrained=True)
     for stage in ("layer2", "layer3", "layer4"):
         setattr(trunk, stage, nn.Identity())
     # Replace the classifier with a small MLP head
     trunk.fc = nn.Sequential(
         nn.Linear(64, 512), nn.ReLU(), nn.Linear(512, self.num_classes)
     )
     self.r2plus1 = trunk
示例#14
0
    def __init__(self, num_classes=5, pretrained=True):
        """Two-stream model: frozen R(2+1)D-18 branch plus a ResNet-18 branch.

        NOTE(review): `num_classes` is ignored; the superclass always gets
        num_classes=2 — confirm this is intended.
        """
        super().__init__(num_classes=2, sequence_length=8, contains_dropout=False)
        video_branch = r2plus1d_18(pretrained=pretrained)
        video_branch.fc = nn.Identity()
        self.r2plus1 = video_branch
        # Video features are used as-is: the extractor is not trained
        self._set_requires_grad_for_module(self.r2plus1, requires_grad=False)

        self.resnet = resnet18(pretrained=pretrained, num_classes=1000)
        self.resnet.fc = nn.Identity()

        self.relu = nn.ReLU()
        # Head over a 1024-d concatenated input (512 from each branch,
        # presumably — confirm in forward)
        self.out = nn.Sequential(
            nn.Linear(1024, 50), nn.ReLU(), nn.Linear(50, self.num_classes)
        )
        self._init = False
示例#15
0
    def __init__(self, pretrained):
        """R(2+1)D-18 truncated after layer1 plus a 1-D pooling block."""
        super().__init__()
        trunk = r2plus1d_18(pretrained=pretrained)
        # Keep only stem + layer1; all later stages and the classifier
        # become no-ops, so the network outputs 64-d features
        trunk.layer2 = nn.Identity()
        trunk.layer3 = nn.Identity()
        trunk.layer4 = nn.Identity()
        trunk.fc = nn.Identity()  # output is 64
        self.r2plus1 = trunk

        # Unsqueeze -> 1-D max pool (kernel/stride 8) -> squeeze -> dropout
        self.video_pooling = nn.Sequential(
            SqueezeModule(1, squeeze=False, dim=1),
            nn.MaxPool1d(8, 8),
            SqueezeModule(1, squeeze=True, dim=1),
            nn.Dropout(0.3),
        )
示例#16
0
    def __init__(self, num_classes):
        """Fusion of a frozen R(2+1)D-18 branch and a ResNet-18 branch.

        NOTE(review): `num_classes` is ignored; the superclass always gets
        num_classes=2 — confirm this is intended.
        """
        super().__init__(num_classes=2)
        video_branch = r2plus1d_18(pretrained=True)
        video_branch.fc = nn.Identity()
        self.r2plus1 = video_branch
        self._set_requires_grad_for_module(self.r2plus1, requires_grad=False)

        self.sync_net = resnet18(pretrained=True, num_classes=1000)
        self.sync_net.fc = nn.Identity()
        # NOTE(review): despite the name, this is an Identity, not a ReLU
        self.relu = nn.Identity()

        # Head over the concatenated 512 + 512 branch features
        self.out = nn.Sequential(
            nn.Linear(512 + 512, 50),
            nn.BatchNorm1d(50),
            nn.LeakyReLU(0.2),
            nn.Linear(50, self.num_classes),
        )
    def __init__(self, num_classes=5, pretrained=True):
        """Similarity network combined with a frozen R(2+1)D-18 extractor.

        NOTE(review): both `num_classes` and `pretrained` are ignored — the
        superclass gets num_classes=2 and the extractor is always built with
        pretrained=True. Confirm this is intended.
        """
        super().__init__(num_classes=2,
                         sequence_length=8,
                         contains_dropout=False)
        self.similarity_net = SimilarityNetBigFiltered(num_classes=5,
                                                       sequence_length=8)
        extractor = r2plus1d_18(pretrained=True)
        extractor.fc = nn.Identity()
        self.video_extractor = extractor
        self._set_requires_grad_for_module(self.video_extractor,
                                           requires_grad=False)

        # Head over a concatenated 512 + 512 input
        self.out = nn.Sequential(
            nn.Linear(512 + 512, 50),
            nn.BatchNorm1d(50),
            nn.LeakyReLU(0.2),
            nn.Linear(50, self.num_classes),
        )
示例#18
0
    def __init__(self, num_classes=5, pretrained=True):
        """Truncated R(2+1)D-18 video branch fused with a frozen SyncNet."""
        super().__init__(num_classes=num_classes,
                         sequence_length=8,
                         contains_dropout=False)
        trunk = r2plus1d_18(pretrained=pretrained)
        # Keep only stem + layer1 (64-d features); later stages and the
        # classifier become no-ops
        for stage in ("layer2", "layer3", "layer4"):
            setattr(trunk, stage, nn.Identity())
        trunk.fc = nn.Identity()
        self.r2plus1 = trunk

        self.sync_net = PretrainedSyncNet()
        self._set_requires_grad_for_module(self.sync_net, requires_grad=False)

        self.relu = nn.ReLU()
        # Head over a 64 + 1024 dimensional concatenated input
        self.out = nn.Sequential(nn.Linear(64 + 1024, 50), nn.ReLU(),
                                 nn.Linear(50, self.num_classes))
    def __init__(self, num_classes=5, sequence_length=8, pretrained=True):
        """Audio-video model with a conv channel-reduction filter and attention.

        Arguments:
            num_classes {int} -- forwarded to the superclass
            sequence_length {int} -- forwarded to the superclass
            pretrained {bool} -- load pre-trained weights for both extractors
        """
        super().__init__(num_classes=num_classes,
                         sequence_length=sequence_length)

        # Earlier variant kept for reference: extractor + projection head.
        # self.r2plus1 = nn.Sequential(
        #     r2plus1d_18(pretrained=pretrained),
        #     nn.Linear(512, 512),
        #     nn.BatchNorm1d(512),
        #     nn.LeakyReLU(0.2),
        # )
        # self.r2plus1[0].fc = nn.Identity()
        self.r2plus1 = r2plus1d_18(pretrained=pretrained)
        self.r2plus1.fc = nn.Identity()

        # Earlier variant kept for reference: extractor + projection head.
        # self.audio_extractor = nn.Sequential(
        #     resnet18(pretrained=pretrained, num_classes=1000),
        #     nn.Linear(512, 512),
        #     nn.BatchNorm1d(512),
        #     nn.LeakyReLU(0.2),
        # )
        # self.audio_extractor[0].fc = nn.Identity()
        self.audio_extractor = resnet18(pretrained=pretrained,
                                        num_classes=1000)
        self.audio_extractor.fc = nn.Identity()

        # 1-D conv stack squeezing 512 channels down to 1 while preserving
        # sequence length (every conv: kernel 3, stride 1, padding 1).
        # NOTE(review): the inline channel comments below (16/8/4/2) do not
        # match the actual conv widths (128/32/8/2); kept as found — verify.
        self.filter = nn.Sequential(  # b x 512 x 9
            nn.Conv1d(512, 128, kernel_size=3, stride=1, padding=1,
                      bias=True),  # b x 16 x seq_len
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(128, 32, kernel_size=3, stride=1, padding=1,
                      bias=True),  # b x 8 x seq_len
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(32, 8, kernel_size=3, stride=1, padding=1,
                      bias=True),  # b x 4 x seq_len
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(8, 2, kernel_size=3, stride=1, padding=1,
                      bias=True),  # b x 2 x seq_len
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(2, 1, kernel_size=3, stride=1, padding=1,
                      bias=True),  # b x 1 x seq_len
            nn.LeakyReLU(0.02, True),
        )

        # Linear + softmax over 9 positions; `attention` and `attentionNet`
        # deliberately alias the same module.
        self.attention = self.attentionNet = nn.Sequential(
            nn.Linear(9, 9, bias=True), nn.Softmax(dim=1))
示例#20
0
    def __init__(self, num_classes=5, pretrained=True):
        """Two branches — R(2+1)D-18 and ResNet-18 — both cut after layer2.

        The head takes a 256-d input (presumably 128-d features from each
        truncated branch, concatenated — confirm in forward).
        """
        super().__init__(
            num_classes=num_classes, sequence_length=8, contains_dropout=False
        )
        video_branch = r2plus1d_18(pretrained=pretrained)
        video_branch.layer3 = nn.Identity()
        video_branch.layer4 = nn.Identity()
        video_branch.fc = nn.Identity()
        self.r2plus1 = video_branch

        second_branch = resnet18(pretrained=pretrained, num_classes=1000)
        second_branch.layer3 = nn.Identity()
        second_branch.layer4 = nn.Identity()
        second_branch.fc = nn.Identity()
        self.resnet = second_branch

        self.relu = nn.ReLU()
        self.out = nn.Sequential(
            nn.Linear(256, 50), nn.ReLU(), nn.Linear(50, self.num_classes)
        )
示例#21
0
def r2plus1d_18(*, num_classes, pretrained=False, progress=True, **kwargs):
    """Build R(2+1)D-18 with a custom classifier size.

    When `pretrained` is True, every weight except the final fc layer is
    loaded from the published checkpoint, so any `num_classes` works.

    Arguments:
        num_classes {int} -- size of the (randomly initialised) fc layer
        pretrained {bool} -- load pre-trained weights for the trunk
        progress {bool} -- show a download progress bar for the checkpoint
    """
    model = video.r2plus1d_18(pretrained=False,
                              progress=False,
                              num_classes=num_classes,
                              **kwargs)
    if not pretrained:
        return model

    arch = 'r2plus1d_18'
    state_dict = load_state_dict_from_url(video.resnet.model_urls[arch],
                                          progress=progress)
    # The checkpoint's classifier shape does not match `num_classes`; drop it
    for key in ('fc.weight', 'fc.bias'):
        del state_dict[key]

    result = model.load_state_dict(state_dict, strict=False)
    # Sanity check: only the classifier weights should be missing
    assert set(result.missing_keys) == {'fc.weight', 'fc.bias'}
    return model
示例#22
0
 def __init__(self, num_classes=2, sequence_length=8, contains_dropout=False):
     """R(2+1)D-18 cut after layer2 with a deep MLP classification head.

     NOTE(review): `num_classes` is ignored; the superclass always gets
     num_classes=2 — confirm this is intended.
     """
     super(BinaryEvaluationMixin, self).__init__(
         num_classes=2,
         sequence_length=sequence_length,
         contains_dropout=contains_dropout,
     )
     trunk = r2plus1d_18(pretrained=True)
     trunk.layer3 = nn.Identity()
     trunk.layer4 = nn.Identity()
     # Replace the classifier with an MLP over the 128-d truncated features
     trunk.fc = nn.Sequential(
         nn.Linear(128, 128),
         nn.ReLU(),
         nn.Linear(128, 1024),
         nn.ReLU(),
         nn.Linear(1024, 50),
         nn.ReLU(),
         nn.Linear(50, self.num_classes),
     )
     self.r2plus1 = trunk
示例#23
0
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        """Encoder-decoder fed by a mostly-frozen R(2+1)D-18 convnet."""
        super(EncoderDecoder, self).__init__()
        # All children of r2plus1d_18 except the final fc layer
        self.convnet = nn.Sequential(
            *list(r2plus1d_18(pretrained=True).children())[:-1])

        # Freeze the whole trunk first ...
        for stage in self.convnet:
            for weight in stage.parameters():
                weight.requires_grad = False
        # ... then re-enable training for the single sub-block at [4][1]
        # (presumably the second block of the last residual stage — confirm)
        for weight in self.convnet[4][1].parameters():
            weight.requires_grad = True

        self.len = 30
        # Projects 512-d convnet features down to 128
        self.intermediate = nn.Linear(512, 128)
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator
    def __init__(self, num_classes=101):
        """R(2+1)D-18 split into frozen trunk, tunable tail, and linear head.

        Arguments:
            num_classes {int} -- size of the final classification layer
        """
        super(R2Plus1D, self).__init__()

        backbone = r2plus1d_18(pretrained=True)

        # Frozen part: stem through layer3
        self.encoder_freeze = nn.Sequential(
            backbone.stem,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
        )
        self.encoder_freeze.apply(freeze_all)

        # Fine-tuned part: layer4 + global average pooling
        self.encoder_tune = nn.Sequential(
            backbone.layer4,
            nn.AdaptiveAvgPool3d(output_size=(1, 1, 1)),
        )

        # Linear classifier sized from the backbone's original fc input
        self.classifier = nn.Linear(backbone.fc.in_features, num_classes)
        self.classifier.apply(initialize_linear)
    def __init__(self, num_classes=5, sequence_length=8, pretrained=True):
        """Contrastive audio-video model; both extractors cut after layer1."""
        super().__init__(
            num_classes=num_classes,
            sequence_length=sequence_length,
            contains_dropout=False,
        )
        video_net = r2plus1d_18(pretrained=pretrained)
        audio_net = resnet18(pretrained=pretrained, num_classes=1000)
        # Truncate both networks after layer1; later stages and the
        # classifiers are replaced with no-ops
        for net in (video_net, audio_net):
            net.layer2 = nn.Identity()
            net.layer3 = nn.Identity()
            net.layer4 = nn.Identity()
            net.fc = nn.Identity()
        self.r2plus1 = video_net
        self.audio_extractor = audio_net

        # Contrastive objective (parameter 20 — presumably a margin; confirm)
        self.c_loss = ContrastiveLoss(20)

        self.log_class_loss = False
示例#26
0
    def __init__(self, num_classes=5, sequence_length=8, pretrained=True):
        """Truncated R(2+1)D-18 + MLP projection, paired with a frozen SyncNet."""
        super().__init__(
            num_classes=num_classes,
            sequence_length=sequence_length,
            contains_dropout=False,
        )
        trunk = r2plus1d_18(pretrained=pretrained)
        # Keep only stem + layer1 (64-d output); the rest become no-ops
        trunk.layer2 = nn.Identity()
        trunk.layer3 = nn.Identity()
        trunk.layer4 = nn.Identity()
        trunk.fc = nn.Identity()
        self.r2plus1 = trunk

        # Project 64-d video features up to a 1024-d embedding
        self.video_mlp = nn.Sequential(
            nn.Linear(64, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Linear(512, 1024)
        )

        self.sync_net = PretrainedSyncNet()
        self._set_requires_grad_for_module(self.sync_net, requires_grad=False)

        # Reuse the (frozen) sync-net's audio branch as the audio extractor
        self.audio_extractor = self.sync_net.audio_extractor

        # Contrastive objective (parameter 20 — presumably a margin; confirm)
        self.c_loss = ContrastiveLoss(20)

        self.log_class_loss = False
示例#27
0
    def __init__(self, num_classes=5, pretrained=True):
        """R(2+1)D-18 cut after layer2, fused with frozen SyncNet features.

        NOTE(review): `pretrained` is ignored — the extractor is always
        built with pretrained=True. Confirm this is intended.
        """
        super().__init__(num_classes=num_classes,
                         sequence_length=8,
                         contains_dropout=False)
        trunk = r2plus1d_18(pretrained=True)
        # Drop layer3/layer4 and the classifier (all become no-ops)
        trunk.layer3 = nn.Identity()
        trunk.layer4 = nn.Identity()
        trunk.fc = nn.Identity()
        self.r2plus1 = trunk

        self.sync_net = PretrainedSyncNet()
        self._set_requires_grad_for_module(self.sync_net, requires_grad=False)

        self.relu = nn.ReLU()

        # Pad one column on the right, then upsample to (8, 56, 56)
        self.padding = nn.ReflectionPad2d((0, 1, 0, 0))
        self.upsample = nn.Upsample(size=(8, 56, 56))

        # Merge the two streams with a (2+1)D conv block
        self.merge_conv: nn.Module = nn.Sequential(
            Conv2Plus1D(128, 64, 144, 1), nn.BatchNorm3d(64),
            nn.ReLU(inplace=True))

        self.out = nn.Sequential(nn.Linear(128, 50), nn.ReLU(),
                                 nn.Linear(50, self.num_classes))
示例#28
0
 def test_r2plus1d_18_video(self):
     """Export-check r2plus1d_18 on a dummy all-ones NxCxTxHxW batch."""
     # torch.autograd.Variable has been a deprecated no-op wrapper since
     # PyTorch 0.4; a plain tensor behaves identically here.
     x = torch.randn(1, 3, 4, 112, 112).fill_(1.0)
     self.exportTest(toC(r2plus1d_18()), toC(x), rtol=1e-3, atol=1e-5)
    def __init__(self, experiment, device):
        """Set up training state for `experiment` from its JSON config file.

        Arguments:
            experiment {str} -- config name; `<CONFIG_DIR>/<experiment>.json`
                must exist
            device {int or str} -- CUDA device index for the model and loss

        Raises:
            AssertionError: if the config file does not exist
            ValueError: for unknown model-id / optim / scheduler values
        """
        config_file = os.path.join(CONFIG_DIR, experiment + '.json')
        assert os.path.exists(
            config_file), 'config file {} does not exist'.format(config_file)
        self.experiment = experiment
        with open(config_file, 'r') as f:
            configs = json.load(f)
        self.device = int(device)

        # Optimisation / loader hyper-parameters from the config
        self.lr = configs['lr']
        self.max_epochs = configs['max-epochs']
        self.train_batch_size = configs['train-batch-size']
        self.test_batch_size = configs['test-batch-size']
        self.n_epochs = 0
        self.n_test_segments = configs['n-test-segments']

        # TensorBoard log directory (created on demand)
        self.log_dir = os.path.join(LOG_DIR, experiment)
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)
        self.tboard_writer = tensorboardX.SummaryWriter(log_dir=self.log_dir)

        self.checkpoint_dir = os.path.join(CHECKPOINT_DIR, experiment)
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        # Select the 3D-CNN backbone named by the config
        model_id = configs['model-id']
        if model_id == 'r3d':
            self.model = models.r3d_18(pretrained=True)
        elif model_id == 'mc3':
            self.model = models.mc3_18(pretrained=True)
        elif model_id == 'r2plus1d':
            self.model = models.r2plus1d_18(pretrained=True)
        else:
            raise ValueError('no such model')
        # replace the last layer to match the dataset's class count
        self.model.fc = nn.Linear(self.model.fc.in_features,
                                  out_features=breakfast.N_CLASSES,
                                  bias=self.model.fc.bias is not None)
        self.model = self.model.cuda(self.device)
        self.loss_fn = nn.CrossEntropyLoss().cuda(self.device)
        if configs['optim'] == 'adam':
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        elif configs['optim'] == 'sgd':
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.lr,
                                       momentum=configs['momentum'],
                                       nesterov=configs['nesterov'])
        else:
            raise ValueError('no such optimizer')

        if configs['scheduler'] == 'step':
            self.scheduler = optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=configs['lr-step'],
                gamma=configs['lr-decay'])
        elif configs['scheduler'] == 'plateau':
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode='min', patience=configs['lr-step'])
        else:
            raise ValueError('no such scheduler')
        # restore previous training state (behaviour of _load_checkpoint is
        # defined elsewhere in this class)
        self._load_checkpoint()
        self.frame_stride = configs['frame-stride']
def main(model_name, 
         mode,
         root,
         val_split,
         ckpt,
         batch_per_gpu):
    num_gpus = MPI.COMM_WORLD.Get_size()
    distributed = False
    if num_gpus > 1:
        distributed = True

    local_rank = MPI.COMM_WORLD.Get_rank() % torch.cuda.device_count()

    if distributed:
        torch.cuda.set_device(local_rank)
        host = os.environ["MASTER_ADDR"] if "MASTER_ADDR" in os.environ else "127.0.0.1"
        torch.distributed.init_process_group(
            backend="nccl",
            init_method='tcp://{}:12345'.format(host),
            rank=MPI.COMM_WORLD.Get_rank(),
            world_size=MPI.COMM_WORLD.Get_size()
        )

        synchronize()

    val_dataloader = make_dataloader(root,
                                        val_split, 
                                        mode,
                                        model_name,
                                        seq_len=16, #64, 
                                        overlap=8, #32,
                                        phase='val', 
                                        max_iters=None, 
                                        batch_per_gpu=batch_per_gpu,
                                        num_workers=16, 
                                        shuffle=False, 
                                        distributed=distributed,
                                        with_normal=False)

    if model_name == 'i3d':
        if mode == 'flow':
            model = InceptionI3d(val_dataloader.dataset.num_classes, in_channels=2, dropout_keep_prob=0.5)
        else:
            model = InceptionI3d(val_dataloader.dataset.num_classes, in_channels=3, dropout_keep_prob=0.5)
        model.replace_logits(val_dataloader.dataset.num_classes)
    elif model_name == 'r3d_18':
        model = r3d_18(pretrained=False, num_classes=val_dataloader.dataset.num_classes)
    elif model_name == 'mc3_18':
        model = mc3_18(pretrained=False, num_classes=val_dataloader.dataset.num_classes)
    elif model_name == 'r2plus1d_18':
        model = r2plus1d_18(pretrained=False, num_classes=val_dataloader.dataset.num_classes)
    elif model_name == 'c3d':
        model = C3D(pretrained=False, num_classes=val_dataloader.dataset.num_classes)
    else:
        raise NameError('unknown model name:{}'.format(model_name))

    # pdb.set_trace()
    for param in model.parameters():
        pass
    
    device = torch.device('cuda')
    model.to(device)
    if distributed:
        model = apex.parallel.convert_syncbn_model(model)
        model = DDP(model.cuda(), delay_allreduce=True)