Example #1
    def init_network(self, args, num_classes):
        # Base semantic segmentation network
        self.segnet = self.init_segnet(args, num_classes)
        self.segnet.to(self.device)
        self.segnet.eval()

        # Motion-compensated layer
        self.mc_layer = MC_Module_Batch(train_noise_sigma=False,
                                        eval_noise_sigma=False)
        self.mc_layer.to(self.device)
        self.mc_layer.eval()

        # Optical flow network
        self.ofnet = self.init_optical_flow(args)
        self.ofnet.eval()

        if self.distributed:
            self.ofnet = self.ofnet.to(self.device)
            print("=> rank %i: initializing of" % self.local_rank)
            self.ofnet = DistributedDataParallel(self.ofnet)
            dist.barrier()
        else:
            self.ofnet = nn.DataParallel(self.ofnet.to(self.device))

        return End2End(
            GuidedSpatiallyVaryingConv(freeze_bn=args.freeze_bn,
                                       in_class=num_classes,
                                       out_class=num_classes,
                                       num_filter=32,
                                       batch_norm=nn.BatchNorm2d,
                                       is_train_filter=args.train_filter))
Example #2
def build_and_initialize_model_and_optimizer(block, args):
    model = args.network_class(args)
    block.log('Number of parameters: {val:,}'.format(val=sum([
        p.data.nelement() if p.requires_grad else 0
        for p in model.parameters()
    ])))

    block.log('Initializing CUDA')
    assert torch.cuda.is_available(), 'only GPUs are supported at the moment'
    model.cuda(torch.cuda.current_device())

    optimizer = args.optimizer_class(
        [p for p in model.parameters() if p.requires_grad], lr=args.lr)

    block.log("Attempting to Load checkpoint '{}'".format(args.resume))
    if args.resume and os.path.isfile(args.resume):
        load_model(model, optimizer, block, args)
    elif args.resume:
        block.log("No checkpoint found at '{}'".format(args.resume))
        exit(1)
    else:
        block.log("Random initialization, checkpoint not provided.")
        args.start_epoch = 0

    if args.fp16:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # Wrap with DistributedDataParallel when running multi-process.
    if args.world_size > 1:
        model = DistributedDataParallel(model)

    return model, optimizer
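The helper above wires together amp and DistributedDataParallel but leaves the training step to the caller. Below is a minimal sketch of the step that typically pairs with this setup; `train_loader` and `criterion` are assumptions for illustration, not part of the original example.

from apex import amp

def train_one_epoch(model, optimizer, train_loader, criterion, fp16=False):
    # Sketch only: mirrors the fp16/DDP setup built above.
    model.train()
    for data, target in train_loader:
        data = data.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        loss = criterion(model(data), target)
        optimizer.zero_grad()
        if fp16:
            # apex amp pattern: scale the loss before the backward pass
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()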
Example #3
    def _make_parallel(runner, net):
        if runner.configer.get('network.distributed', default=False):
            from apex.parallel import DistributedDataParallel
            if runner.configer.get('network.syncbn', default=False):
                Log.info('Converting syncbn model...')
                from apex.parallel import convert_syncbn_model
                net = convert_syncbn_model(net)

            torch.cuda.set_device(runner.configer.get('local_rank'))
            torch.distributed.init_process_group(backend='nccl',
                                                 init_method='env://')
            net = DistributedDataParallel(net.cuda(), delay_allreduce=True)
            return net

        net = net.to(
            torch.device(
                'cpu' if runner.configer.get('gpu') is None else 'cuda'))
        if len(runner.configer.get('gpu')) > 1:
            from exts.tools.parallel.data_parallel import ParallelModel
            return ParallelModel(net,
                                 gather_=runner.configer.get(
                                     'network', 'gather'))

        return net
Example #4
 def init_distributed(self, rank, local_rank):
     assert not self.distribured_enabled
     self.distribured_enabled = True
     print("Initializing Distributed, rank {}, local rank {}".format(rank, local_rank))
     dist.init_process_group(backend='nccl', rank=rank)
     torch.cuda.set_device(local_rank)
     self.core = DistributedDataParallel(self.core)
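A hedged sketch of how `init_distributed` is typically driven when the script is started with `python -m torch.distributed.launch`: older launchers pass `--local_rank` as a command-line argument, newer ones export `LOCAL_RANK`. `trainer` is a hypothetical instance, not part of the original example.

import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int,
                    default=int(os.environ.get('LOCAL_RANK', 0)))
cli_args = parser.parse_args()

rank = int(os.environ.get('RANK', 0))
# trainer.init_distributed(rank, cli_args.local_rank)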
Example #5
    def init_from_checkpoint(self, continue_state_object):
        t_start = time.time()

        self.config = continue_state_object['config']
        self._build_environ()
        self.model = _get_model(self.config)
        self.filtered_keys = [
            p.name
            for p in inspect.signature(self.model.forward).parameters.values()
        ]
        # model_params = filter(lambda p: p.requires_grad, self.model.parameters())
        model_params = []
        for params in self.model.optimizer_params():
            params["lr"] = self.config["solver"]["optimizer"]["params"][
                "lr"] * params["lr"]
            model_params.append(params)
        self.optimizer = _get_optimizer(self.config['solver']['optimizer'],
                                        model_params=model_params)
        self.lr_policy = _get_lr_policy(self.config['solver']['lr_policy'],
                                        optimizer=self.optimizer)

        load_model(self.model,
                   continue_state_object['model'],
                   distributed=False)
        self.model.cuda(self.local_rank)

        if self.distributed:
            self.model = convert_syncbn_model(self.model)

        if self.config['apex']['amp_used']:
            # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
            # for convenient interoperation with argparse.
            logging.info(
                "Initialize Amp. opt level={}, keep batchnorm fp32={}, loss_scale={}."
                .format(self.config['apex']['opt_level'],
                        self.config['apex']['keep_batchnorm_fp32'],
                        self.config['apex']['loss_scale']))
            self.model, self.optimizer = amp.initialize(
                self.model,
                self.optimizer,
                opt_level=self.config['apex']['opt_level'],
                keep_batchnorm_fp32=self.config['apex']["keep_batchnorm_fp32"],
                loss_scale=self.config['apex']["loss_scale"])
            amp.load_state_dict(continue_state_object['amp'])

        if self.distributed:
            self.model = DistributedDataParallel(self.model)

        self.optimizer.load_state_dict(continue_state_object['optimizer'])
        self.lr_policy.load_state_dict(continue_state_object['lr_policy'])

        self.step_decay = self.config['solver']['step_decay']
        self.epoch = continue_state_object['epoch']
        self.iteration = continue_state_object["iteration"]

        del continue_state_object
        t_end = time.time()
        logging.info(
            "Init trainer from checkpoint, Time usage: IO: {}".format(t_end -
                                                                      t_start))
Example #6
def main_worker(gpu, ngpus_per_node, args):
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(gpu)

    model = MyModel()
    model.cuda()

    args.batch_size = int(args.batch_size / ngpus_per_node)
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr)

    model, optimizer = amp.initialize(model,
                                      optimizer)
    model = DistributedDataParallel(model)
    cudnn.benchmark = True

    train_dataset = MyDataset()

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank())

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=2,
                                               pin_memory=True,
                                               sampler=train_sampler)

    train_loader2 = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=2,
                                               pin_memory=True)
    
    for epoch in range(5):
        train_sampler.set_epoch(epoch)
        model.train()

        for i, (data, label) in enumerate(train_loader):
            
            data = data.cuda(non_blocking=True)
            label = label.cuda(non_blocking=True)
            output = model(data)
            loss = criterion(output, label)

            # print('epoch', epoch, 'gpu', gpu)
            # params = list(model.named_parameters())
            # for i in range(len(params)):
            #     (name, param) = params[i]
            #     print(name)
            #     print(param.grad)

            print('epoch', epoch, 'iter', i, 'gpu', gpu)
            print(data)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
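A side note on the sampler used above: DistributedSampler can infer the replica count and rank from the initialized process group, so the explicit arguments are optional. A minimal equivalent, assuming the same `train_dataset`:

train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)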
Example #7
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        super().__init__()
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(
                logging.INFO):  # setup_logger is not called for fastreid
            setup_logger()

        # Assume these objects must be constructed in this order.
        data_loader = self.build_train_loader(cfg)
        cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes)
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)

        optimizer_ckpt = dict(optimizer=optimizer)
        if cfg.SOLVER.FP16_ENABLED:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
            optimizer_ckpt.update(dict(amp=amp))

        # For training, wrap with DDP; it is not needed for inference.
        if comm.get_world_size() > 1:
            # See https://github.com/pytorch/pytorch/issues/22049: set `find_unused_parameters=True`
            # because some of the parameters are not updated.
            # model = DistributedDataParallel(
            #     model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
            # )
            model = DistributedDataParallel(model, delay_allreduce=True)

        self._trainer = (AMPTrainer if cfg.SOLVER.FP16_ENABLED else
                         SimpleTrainer)(model, data_loader, optimizer)

        self.iters_per_epoch = len(
            data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
        self.scheduler = self.build_lr_scheduler(cfg, optimizer,
                                                 self.iters_per_epoch)

        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            **optimizer_ckpt,
            **self.scheduler,
        )

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())
Example #8
def initTrain(opt, epoch=None, rank=0):
    model = Model(opt).train()
    if epoch:
        initParameters(opt, model)
        if type(epoch) == int:
            model.load_state_dict(
                torch.load(modelPath(epoch), map_location='cpu'))
    model = model.to(opt.device)  # need before constructing optimizers
    paramOptions = getParamOptions(opt, model)
    eps = 1e-4 if opt.fp16 else 1e-8
    opt.optimizer = opt.newOptimizer(opt, paramOptions, eps)
    if opt.fp16:
        model, opt.optimizer = amp.initialize(model,
                                              opt.optimizer,
                                              opt_level="O{}".format(opt.fp16),
                                              **opt.ampArgs)
    if opt.sdt_decay_step > 0:
        gamma = opt.gamma if hasattr(opt, 'gamma') else .5
        opt.scheduler = optim.lr_scheduler.StepLR(opt.optimizer,
                                                  opt.sdt_decay_step,
                                                  gamma=gamma)
    else:
        opt.scheduler = None
    if type(epoch) == int and os.path.isfile(statePath(epoch)):
        state = torch.load(statePath(epoch), map_location='cpu')
        opt.optimizer.load_state_dict(state[0])
        if opt.scheduler:
            opt.scheduler.load_state_dict(state[1])
        if opt.fp16 and len(state) > 2:
            amp.load_state_dict(state[2])
    if opt.mp:
        if opt.cuda:
            from apex.parallel import DistributedDataParallel, convert_syncbn_model
            model = DistributedDataParallel(convert_syncbn_model(model),
                                            message_size=getNelement(model) - 1)
        else:
            from torch.nn.parallel import DistributedDataParallel
            model = DistributedDataParallel(model,
                                            device_ids=[opt.devices[rank]])
    if opt.cuda and rank == 0:
        print('GPU memory allocated before training: {} bytes'.format(
            torch.cuda.max_memory_allocated()))
        torch.cuda.reset_max_memory_allocated()
    return opt, model
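For reference, a hedged sketch of the save side that would produce the state layout `initTrain` restores above (optimizer state at index 0, scheduler state at index 1, amp state at index 2). `saveTrainState` is a hypothetical helper; `statePath`, `torch`, and `amp` are the same names the example relies on.

def saveTrainState(opt, epoch):
    # Hypothetical helper, not part of the original example.
    state = [opt.optimizer.state_dict(),
             opt.scheduler.state_dict() if opt.scheduler else None]
    if opt.fp16:
        state.append(amp.state_dict())
    torch.save(state, statePath(epoch))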
Example #9
    def init_from_scratch(self, config):
        t_start = time.time()
        self.config = config
        self._build_environ()
        # model and optimizer
        self.model = _get_model(self.config)
        self.filtered_keys = [
            p.name
            for p in inspect.signature(self.model.forward).parameters.values()
        ]
        # logging.info("filtered keys:{}".format(self.filtered_keys))
        # model_params = filter(lambda p: p.requires_grad, self.model.parameters())
        model_params = []
        for params in self.model.optimizer_params():
            params["lr"] = self.config["solver"]["optimizer"]["params"][
                "lr"] * params["lr"]
            model_params.append(params)
        self.optimizer = _get_optimizer(config['solver']['optimizer'],
                                        model_params=model_params)

        self.lr_policy = _get_lr_policy(config['solver']['lr_policy'],
                                        optimizer=self.optimizer)
        self.step_decay = config['solver']['step_decay']

        if config['model'].get('pretrained_model') is not None:
            logging.info('loading pretrained model from {}.'.format(
                config['model']['pretrained_model']))
            load_model(self.model,
                       config['model']['pretrained_model'],
                       distributed=False)

        self.model.cuda(self.local_rank)

        if self.distributed:
            self.model = convert_syncbn_model(self.model)

        if self.config['apex']['amp_used']:
            # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
            # for convenient interoperation with argparse.
            logging.info(
                "Initialize Amp. opt level={}, keep batchnorm fp32={}, loss_scale={}."
                .format(self.config['apex']['opt_level'],
                        self.config['apex']['keep_batchnorm_fp32'],
                        self.config['apex']['loss_scale']))
            self.model, self.optimizer = amp.initialize(
                self.model,
                self.optimizer,
                opt_level=self.config['apex']['opt_level'],
                keep_batchnorm_fp32=self.config['apex']["keep_batchnorm_fp32"],
                loss_scale=self.config['apex']["loss_scale"])
        if self.distributed:
            self.model = DistributedDataParallel(self.model)

        t_end = time.time()
        logging.info(
            "Init trainer from scratch, Time usage: IO: {}".format(t_end -
                                                                   t_start))
Example #10
def train_single_jigsaw(args, device_id):
    init_logger(args.log_file)

    # device = "cpu" if args.visible_gpus == '-1' else "cuda"
    device = 'cuda'
    # logger.info('Device ID %d' % device_id)
    # logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    # torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        # torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    # torch.backends.cudnn.deterministic = True

    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
    else:
        checkpoint = None

    def train_iter_fct():
        return jigsaw_data_loader.Dataloader(
            args, jigsaw_data_loader.load_dataset(args, 'train', shuffle=True),
            args.batch_size, device, shuffle=True, is_test=False)
    jigsaw = args.jigsaw if 'jigsaw' in args else 'jigsaw_lab'
    if jigsaw == 'jigsaw_dec':
        model = SentenceTransformer(args, device, checkpoint, sum_or_jigsaw=1)
    else:
        model = Jigsaw(args, device, checkpoint)
    optim = build_optim(args, model, checkpoint)

    logger.info(model)
    if args.fp16:
        opt_level = 'O1'  # typical fp16 training, can also try O2 to compare performance
    else:
        opt_level = 'O0'  # pure fp32 training
    model, optim.optimizer = amp.initialize(model, optim.optimizer, opt_level=opt_level)
    if args.distributed:
        # FOR DISTRIBUTED:  After amp.initialize, wrap the model with
        # apex.parallel.DistributedDataParallel.
        model = DistributedDataParallel(model)

    # logger.info('type(optim)'+str(type(optim)))
    trainer = build_trainer(args, device_id, model, optim)
    trainer.train(train_iter_fct, args.train_steps)
Example #11
def to_ddp(modules: Union[list, nn.Module],
           optimizer: torch.optim.Optimizer = None,
           opt_level: int = 0) -> Union[DistributedDataParallel, tuple]:
    if isinstance(modules, list):
        modules = [x.cuda() for x in modules]
    else:
        modules = modules.cuda()
    if optimizer is not None:
        modules, optimizer = amp.initialize(modules,
                                            optimizer,
                                            opt_level="O{}".format(opt_level),
                                            verbosity=1)
    if isinstance(modules, list):
        modules = [
            DistributedDataParallel(x, delay_allreduce=True) for x in modules
        ]
    else:
        modules = DistributedDataParallel(modules, delay_allreduce=True)
    if optimizer is not None:
        return modules, optimizer
    else:
        return modules
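A hedged usage sketch for the `to_ddp` helper above. It assumes the default process group has already been initialized (apex DistributedDataParallel requires it); `MyModel` is a placeholder module.

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)

model = MyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Wraps the module with apex DistributedDataParallel and amp at opt level O1.
model, optimizer = to_ddp(model, optimizer, opt_level=1)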
Example #12
def process_components(
    model: _Model,
    criterion: _Criterion = None,
    optimizer: _Optimizer = None,
    scheduler: _Scheduler = None,
    distributed_params: Dict = None
) -> Tuple[_Model, _Criterion, _Optimizer, _Scheduler, torch.device]:
    distributed_params = distributed_params or {}
    distributed_params = copy.deepcopy(distributed_params)
    device = utils.get_device()

    model = maybe_recursive_call(model, "to", device=device)

    if utils.is_wrapped_with_ddp(model):
        pass
    elif len(distributed_params) > 0:
        assert isinstance(model, nn.Module)
        utils.assert_fp16_available()
        from apex import amp
        from apex.parallel import convert_syncbn_model

        distributed_rank = distributed_params.pop("rank", -1)
        syncbn = distributed_params.pop("syncbn", False)

        if distributed_rank > -1:
            torch.cuda.set_device(distributed_rank)
            torch.distributed.init_process_group(
                backend="nccl", init_method="env://"
            )

        model, optimizer = amp.initialize(
            model, optimizer, **distributed_params
        )

        if distributed_rank > -1:
            from apex.parallel import DistributedDataParallel
            model = DistributedDataParallel(model)

            if syncbn:
                model = convert_syncbn_model(model)
        elif torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
    elif torch.cuda.device_count() > 1:
        if isinstance(model, nn.Module):
            model = torch.nn.DataParallel(model)
        elif isinstance(model, dict):
            model = {k: torch.nn.DataParallel(v) for k, v in model.items()}

    model = maybe_recursive_call(model, "to", device=device)

    return model, criterion, optimizer, scheduler, device
Example #13
    def make_parallel(runner, net, optimizer):
        if runner.configer.get('distributed', default=False):
            from apex.parallel import DistributedDataParallel
            if runner.configer.get('network.syncbn', default=False):
                Log.info('Converting syncbn model...')
                from apex.parallel import convert_syncbn_model
                net = convert_syncbn_model(net)
            torch.cuda.set_device(runner.configer.get('local_rank'))
            torch.distributed.init_process_group(backend='nccl', init_method='env://')
            if runner.configer.get('dtype') == 'fp16':
                from apex import amp
                net, optimizer = amp.initialize(net.cuda(), optimizer, opt_level="O1")
                net = DistributedDataParallel(net, delay_allreduce=True)
            else:
                assert runner.configer.get('dtype') == 'none'
                net = DistributedDataParallel(net.cuda(), delay_allreduce=True)
            return net, optimizer
        net = net.to(torch.device('cpu' if runner.configer.get('gpu') is None else 'cuda'))
        if len(runner.configer.get('gpu')) > 1:
            from lib.utils.parallel.data_parallel import DataParallelModel
            return DataParallelModel(net, gather_=runner.configer.get('network', 'gather')), optimizer

        return net, optimizer
Example #14
    def setup_model(self, trainer: Trainer):

        trainer.model = move_to_device(trainer.model, trainer.device)

        # FP16
        if self.config.training.fp16:
            trainer.model, trainer.optimizer = amp.initialize(
                trainer.model,
                trainer.optimizer,
                opt_level=self.config.training.fp16_opt_level)

        if self.config.training.num_gpus_per_node > 1:
            # Distributed training (should be after apex fp16 initialization)
            trainer.model = DistributedDataParallel(trainer.model,
                                                    delay_allreduce=True)
Example #15
    def __init__(self, conf, inference=False):
        accuracy = 0.0
        logger.debug(conf)
        if conf.use_mobilfacenet:
            # self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            self.model = MobileFaceNet(conf.embedding_size).cuda()
            logger.debug('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio, conf.net_mode).cuda()#.to(conf.device)
            logger.debug('{}_{} model generated'.format(conf.net_mode, conf.net_depth))
        if not inference:
            self.milestones = conf.milestones
            logger.info('loading data...')
            self.loader, self.class_num = get_train_loader(conf, 'emore', sample_identity=True)

            self.writer = SummaryWriter(conf.log_path)
            self.step = 0
            self.head = CircleLoss(m=0.25, gamma=256.0).cuda()

            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)

            if conf.use_mobilfacenet:
                self.optimizer = optim.SGD([
                                    {'params': paras_wo_bn[:-1], 'weight_decay': 4e-5},
                                    {'params': [paras_wo_bn[-1]], 'weight_decay': 4e-4},
                                    {'params': paras_only_bn}
                                ], lr = conf.lr, momentum = conf.momentum)
            else:
                self.optimizer = optim.SGD([
                                    {'params': paras_wo_bn, 'weight_decay': 5e-4},
                                    {'params': paras_only_bn}
                                ], lr = conf.lr, momentum = conf.momentum)
            # self.optimizer = torch.nn.parallel.DistributedDataParallel(optimizer,device_ids=[conf.argsed])
            # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)

            if conf.fp16:
                self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O2")
                self.model = DistributedDataParallel(self.model).cuda()
            else:
                self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[conf.argsed]).cuda() #add line for distributed

            self.board_loss_every = len(self.loader)//100
            self.evaluate_every = len(self.loader)//2
            self.save_every = len(self.loader)//2
            self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame, self.cfp_fp_issame, self.lfw_issame = get_val_data(Path(self.loader.dataset.root).parent)
        else:
            self.threshold = conf.threshold
            self.loader, self.query_ds, self.gallery_ds = get_test_loader(conf)
Example #16
    def __init__(self,
                 multilingual_model,
                 config: AdversarialPretrainerConfig,
                 train_data,
                 test_data=None,
                 position=None,
                 seed=None):
        """
        :param multilingual_model: a multilingual sequence model which you want to train
        :param config: config of trainer containing parameters and total word vocab size
        :param train_data: a dictionary of dataloaders specifying train data
        :param test_data: a dictionary of dataloaders specifying test data, if none train_data is used instead
        """

        # Setup cuda device for BERT training, argument -c, --cuda should be true
        self.device = torch.device(config.gpu_id)

        # initialize public, private, and adversarial discriminator
        self.ltoi = config.language_ids
        self.model = AdversarialBertWrapper(multilingual_model, config)

        # move to GPU
        self.model.to(self.device)
        self.model = DistributedDataParallel(self.model, delay_allreduce=True)

        # assign data
        self.train_data = train_data
        self.test_data = test_data if test_data else train_data

        # initialize loss function and optimizers
        self.D_repeat = config.adv_repeat

        # initialize optimizers
        self.D_optim = BertAdam(
            self.model.module.component_parameters("adversary"), config.lr)
        self.lm_optims = BertAdam(self.model.module.component_parameters(),
                                  config.lr)

        # hyperparameters for loss
        self.beta = config.beta
        self.gamma = config.gamma

        # how many iterations to accumulate gradients for
        self.train_freq = config.train_freq if config.train_freq is not None else 1

        self._config = config  # for checkpointing
        self.position = position
        self.seed = seed
Example #17
 def _make_parallel(runner, net):
     if runner.configer.get('network.distributed', default=False):
         from apex.parallel import DistributedDataParallel
         torch.cuda.set_device(runner.configer.get('local_rank'))
         torch.distributed.init_process_group(backend='nccl',
                                              init_method='env://')
         net = DistributedDataParallel(net.cuda(), delay_allreduce=True)
         return net
     else:
         net = net.to(
             torch.device(
                 'cpu' if runner.configer.get('gpu') is None else 'cuda'))
         from exts.tools.parallel.data_parallel import ParallelModel
         return ParallelModel(net,
                              gather_=runner.configer.get(
                                  'network', 'gather'))
Example #18
    def distribute(self, args):
        labels = self.net.labels
        audio_conf = self.net.audio_conf
        main_proc = self.initialize_process_group(args)
        if main_proc and self._visdom:  # Add previous scores to visdom graph
            self._visdom_logger.load_previous_values(self._start_epoch, self._package)
        if main_proc and self._tensorboard:  # Previous scores to tensorboard logs
            self._tensorboard_logger.load_previous_values(self._start_epoch, self._package)
        self.net = DistributedDataParallel(self.net)

        self.net.decoder = GreedyDecoder(labels)
        self.net.audio_conf = audio_conf
        self.net.labels = labels
        self.net.device = self.device
        self.net = self.net.to(self.device)
        return main_proc
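Note that the attributes assigned above (`decoder`, `audio_conf`, `labels`) land on the DistributedDataParallel wrapper, not on the wrapped network. A minimal sketch of reaching the underlying module, assuming an initialized process group; `net` here is a placeholder:

import torch.nn as nn
from apex.parallel import DistributedDataParallel

net = nn.Linear(4, 2).cuda()
wrapped = DistributedDataParallel(net)
assert wrapped.module is net  # the original network stays reachable via .module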
Example #19
def init_classifier(args):
    """Load a classifier, model and initialize a loss function.
    
    Returns:
        model: Classifier instance.
        criterion: Loss function.
        optimizer: Optimization instance for training.    
    """

    if args.architecture == 'basic':
        model = BasicNet(num_classes=args.num_classes).cuda()

    elif args.architecture == 'resnet50':
        model = resnet50(bn0=args.init_bn0,
                         num_classes=args.num_classes,
                         pretrained=args.pretrained).cuda()

    else:
        model = models.__dict__[args.architecture](
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            aux_logits=False).cuda()

    if args.optimizer == 'sgd':
        # start with 0 lr. Scheduler will change this later
        optimizer = optim.SGD(model.parameters(),
                              0,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'adadelta':
        optimizer = optim.Adadelta(model.parameters(), lr=1.0)
    else:
        raise NotImplementedError(
            f'Optimizer {args.optimizer} not implemented')

    if args.half_prec:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level='O1',
                                          verbosity=0)

    if args.distributed:
        model = DistributedDataParallel(model)

    criterion = nn.CrossEntropyLoss().cuda()

    return model, criterion, optimizer
Example #20
    def __init__(self,
                 system_config,
                 model,
                 distributed=False,
                 gpu=None,
                 test=False):
        super(NetworkFactory, self).__init__()

        self.system_config = system_config

        self.gpu = gpu
        self.model = DummyModule(model)
        self.loss = model.loss
        self.network = Network(self.model, self.loss)
        self.freeze = False

        if distributed:
            from apex.parallel import DistributedDataParallel, convert_syncbn_model
            torch.cuda.set_device(gpu)
            self.network = self.network.cuda(gpu)
            self.network = convert_syncbn_model(self.network)
            self.network = DistributedDataParallel(self.network)
        else:
            self.network = DataParallel(self.network,
                                        chunk_sizes=system_config.chunk_sizes)

        total_params = 0
        for params in self.model.parameters():
            num_params = 1
            for x in params.size():
                num_params *= x
            total_params += num_params
        print("total parameters: {}".format(total_params))

        if system_config.opt_algo == "adam":
            self.optimizer = torch.optim.Adam(
                filter(lambda p: p.requires_grad, self.model.parameters()))
        elif system_config.opt_algo == "sgd":
            self.optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                    self.model.parameters()),
                                             lr=system_config.learning_rate,
                                             momentum=0.9,
                                             weight_decay=0.0001)
        else:
            raise ValueError("unknown optimizer")
        if test:
            self.eval_mode()
Example #21
def process_components(
    model: _Model,
    criterion: _Criterion = None,
    optimizer: _Optimizer = None,
    scheduler: _Scheduler = None,
    distributed_params: Dict = None
) -> Tuple[_Model, _Criterion, _Optimizer, _Scheduler, torch.device]:
    distributed_params = distributed_params or {}
    distributed_params = copy.deepcopy(distributed_params)
    device = utils.get_device()

    if torch.cuda.is_available():
        benchmark = os.environ.get("CUDNN_BENCHMARK", "True") == "True"
        cudnn.benchmark = benchmark

    model = model.to(device)

    if utils.is_wrapped_with_ddp(model):
        pass
    elif len(distributed_params) > 0:
        utils.assert_fp16_available()
        from apex import amp

        distributed_rank = distributed_params.pop("rank", -1)

        if distributed_rank > -1:
            torch.cuda.set_device(distributed_rank)
            torch.distributed.init_process_group(
                backend="nccl", init_method="env://"
            )

        model, optimizer = amp.initialize(
            model, optimizer, **distributed_params
        )

        if distributed_rank > -1:
            from apex.parallel import DistributedDataParallel
            model = DistributedDataParallel(model)
        elif torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
    elif torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    model = model.to(device)

    return model, criterion, optimizer, scheduler, device
Example #22
    def __init__(self, model: nn.Module, optimizer: Optimizer, loss_f: Callable, *,
                 callbacks: Callback = None, scheduler: LRScheduler = None,
                 verb=True, use_cudnn_benchmark=True, backend="nccl", init_method="env://",
                 use_sync_bn: bool = False, enable_amp=False, **kwargs):
        if use_sync_bn:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if enable_amp:
            from homura import is_apex_available

            if not is_apex_available:
                raise RuntimeError("apex not installed")

        import sys as python_sys
        from torch import distributed

        # should be used with torch.distributed.launch
        if not is_distributed:
            raise RuntimeError(
                f"For distributed training, use python -m torch.distributed.launch "
                f"--nproc_per_node={torch.cuda.device_count()} {' '.join(python_sys.argv)} ...")

        distributed.init_process_group(backend=backend, init_method=init_method)
        rank = get_local_rank()
        if get_global_rank() > 0:
            # to avoid overwriting
            verb = False
        torch.cuda.set_device(rank)

        super(DistributedSupervisedTrainer, self).__init__(model, optimizer, loss_f, callbacks=callbacks,
                                                           scheduler=scheduler, verb=verb,
                                                           use_cudnn_benchmark=use_cudnn_benchmark,
                                                           use_cuda_nonblocking=True,
                                                           device=torch.device(GPU, rank), **kwargs)

        self.loss_scaler = None
        if enable_amp:
            from apex import amp
            from apex.parallel import DistributedDataParallel

            self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O2")
            self.model = DistributedDataParallel(self.model, delay_allreduce=True)
            self.loss_scaler = amp.scale_loss
        else:
            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[rank])
Example #23
def main(args):
    cfg = setup(args)

    model = build_model(cfg)
    logger.info("Model:\n{}".format(model))
    if args.eval_only:
        cfg.defrost()
        cfg.MODEL.BACKBONE.PRETRAIN = False

        Checkpointer(model).load(cfg.MODEL.WEIGHTS)  # load trained model

        return do_test(cfg, model)

    distributed = comm.get_world_size() > 1
    if distributed:
        model = DistributedDataParallel(model, delay_allreduce=True)

    do_train(cfg, model, resume=args.resume)
    return do_test(cfg, model)
Example #24
    def __init__(self, opt):
        self.opt = opt

        self.pix2pix_model = models.create_model(opt)

        if opt.isTrain:
            self.optimizer_G, self.optimizer_D = self.pix2pix_model.create_optimizers(opt)
            self.old_lr = opt.lr

        if opt.fp16:
            self.pix2pix_model, [self.optimizer_G, self.optimizer_D] = amp.initialize(
                self.pix2pix_model, [self.optimizer_G, self.optimizer_D], num_losses=2)
        self.generated = None

        if opt.distributed:
            self.pix2pix_model = DistributedDataParallel(self.pix2pix_model, delay_allreduce=True)

        if opt.continue_train:
            self.load_checkpoint('latest')
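Because `amp.initialize` is called with `num_losses=2` above, the usual apex pattern gives each loss its own scaler via `loss_id`. A hedged sketch; `g_loss`, `d_loss`, and the helper name are placeholders, not part of the original trainer.

from apex import amp

def backward_g_and_d(g_loss, d_loss, optimizer_G, optimizer_D):
    # Each loss_id keeps an independent dynamic loss scale.
    with amp.scale_loss(g_loss, optimizer_G, loss_id=0) as scaled_loss:
        scaled_loss.backward()
    with amp.scale_loss(d_loss, optimizer_D, loss_id=1) as scaled_loss:
        scaled_loss.backward()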
Example #25
def main(cfgs):
    Logger.init(**cfgs['logger'])

    local_rank = cfgs['local_rank']
    world_size = int(os.environ['WORLD_SIZE'])
    Log.info('rank: {}, world_size: {}'.format(local_rank, world_size))

    log_dir = cfgs['log_dir']
    pth_dir = cfgs['pth_dir']
    if local_rank == 0:
        assure_dir(log_dir)
        assure_dir(pth_dir)

    aux_config = cfgs.get('auxiliary', None)
    network = ModuleBuilder(cfgs['network'], aux_config).cuda()
    criterion = build_criterion(cfgs['criterion'], aux_config).cuda()
    optimizer = optim.SGD(network.parameters(), **cfgs['optimizer'])
    scheduler = PolyLRScheduler(optimizer, **cfgs['scheduler'])

    dataset = build_dataset(**cfgs['dataset'], **cfgs['transforms'])
    sampler = DistributedSampler4Iter(dataset,
                                      world_size=world_size,
                                      rank=local_rank,
                                      **cfgs['sampler'])
    train_loader = DataLoader(dataset, sampler=sampler, **cfgs['loader'])

    cudnn.benchmark = True
    torch.manual_seed(666)
    torch.cuda.manual_seed(666)
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')

    model = DistributedDataParallel(network)
    model = apex.parallel.convert_syncbn_model(model)

    torch.cuda.empty_cache()
    train(local_rank, world_size, pth_dir, cfgs['frequency'], criterion,
          train_loader, model, optimizer, scheduler)
Example #26
    def __init__(self, model: nn.Module, optimizer: Optimizer, loss_f: Callable, *,
                 callbacks: Callback = None, scheduler: LRScheduler = None,
                 verb=True, use_cudnn_benchmark=True, backend="nccl", init_method="env://",
                 use_sync_bn: bool = False, enable_amp=False, **kwargs):

        if use_sync_bn:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if enable_amp:
            from homura import is_apex_available

            if not is_apex_available:
                raise RuntimeError("apex not installed")

        init_distributed(backend, init_method, warning=False)
        rank = get_local_rank()
        if get_global_rank() > 0:
            # to avoid overwriting
            verb = False
        torch.cuda.set_device(rank)

        super(DistributedSupervisedTrainer, self).__init__(model, optimizer, loss_f, callbacks=callbacks,
                                                           scheduler=scheduler, verb=verb,
                                                           use_cudnn_benchmark=use_cudnn_benchmark,
                                                           use_cuda_nonblocking=True,
                                                           device=torch.device(GPU, rank), **kwargs)

        self.loss_scaler = None
        if enable_amp:
            from apex import amp
            from apex.parallel import DistributedDataParallel

            self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O2")
            self.model = DistributedDataParallel(self.model, delay_allreduce=True)
            self.loss_scaler = amp.scale_loss
        else:
            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[rank])
Example #27
def main_worker(gpu, ngpus_per_node, args):
    torch.cuda.set_device(gpu)
    # device = torch.device(f'cuda:{args.local_rank}')
    assert torch.distributed.is_nccl_available()
    dist.init_process_group(backend='nccl', init_method='env://')
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    if args.from_model is None:
        print("=> creating model")
        model = LiteTransformerEncoder(d_model=args.d_model,
                                       d_ff=args.d_ff,
                                       n_head=args.n_head,
                                       num_encoder_layers=args.n_layers,
                                       label_vocab_size=args.label_vocab_size,
                                       dropout=args.dropout).cuda()
        for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     weight_decay=args.weight_decay)
        model = convert_syncbn_model(model)  #Synchronize BN
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=args.opt_level)  #need change
        model = DistributedDataParallel(model, delay_allreduce=True)

    else:  #need change
        print("=> loading from existing model")
        checkpoint = torch.load(args.from_model)
        model_opt = checkpoint['settings']
        model = LiteTransformerEncoder(
            d_model=model_opt.d_model,
            d_ff=model_opt.d_ff,
            n_head=model_opt.n_head,
            num_encoder_layers=model_opt.n_layers,
            label_vocab_size=model_opt.label_vocab_size,
            dropout=model_opt.dropout).cuda()
        args.start_epoch = checkpoint['epoch']
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     weight_decay=args.weight_decay)
        model = convert_syncbn_model(model)  #Synchronize BN
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=args.opt_level)  #need change
        model = DistributedDataParallel(model, delay_allreduce=True)
        model.load_state_dict(checkpoint['model'])
        amp.load_state_dict(checkpoint['amp'])

    model.cuda()

    optimizer = ScheduledOptim(optimizer=optimizer,
                               d_model=args.d_model,
                               n_warmup_steps=args.warmup_steps)

    args.batch_size = int(args.batch_size / ngpus_per_node)
    args.world_size = ngpus_per_node
    train_dataset = TrainBatchBasecallDataset(
        signal_dir=args.train_signal_path, label_dir=args.train_label_path)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=gpu)
    valid_dataset = TrainBatchBasecallDataset(signal_dir=args.test_signal_path,
                                              label_dir=args.test_label_path)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset, num_replicas=args.world_size, rank=gpu)

    list_charcter_error = []
    list_valid_loss = []
    start = time.time()
    show_shape = True
    for epoch in range(args.start_epoch, args.epochs):
        train_sampler.set_epoch(epoch)  # used for pytorch >= 1.2
        train_provider = TrainBatchProvider(train_dataset,
                                            args.batch_size,
                                            num_workers=0,
                                            train_sampler=train_sampler,
                                            pin_memory=True,
                                            shuffle=(train_sampler is None))
        valid_provider = TrainBatchProvider(valid_dataset,
                                            args.batch_size,
                                            num_workers=0,
                                            train_sampler=valid_sampler,
                                            pin_memory=True,
                                            shuffle=False)
        # train
        model.train()
        total_loss = []
        batch_step = 0
        target_decoder = GreedyDecoder('-ATCG ', blank_index=0)
        decoder = BeamCTCDecoder('-ATCG ',
                                 cutoff_top_n=6,
                                 beam_width=3,
                                 blank_index=0)
        while True:
            batch = train_provider.next()
            signal, label = batch
            if signal is not None and label is not None:
                batch_step += 1
                if show_shape:
                    print('gpu {} signal shape:{}'.format(gpu, signal.size()),
                          flush=True)
                    print('gpu {} label shape:{}'.format(gpu, label.size()),
                          flush=True)
                    show_shape = False
                signal = signal.type(torch.FloatTensor).cuda(non_blocking=True)
                label = label.type(torch.LongTensor).cuda(non_blocking=True)
                # forward
                optimizer.zero_grad()
                signal_lengths = signal.squeeze(2).ne(constants.SIG_PAD).sum(1)
                enc_output, enc_output_lengths = model(
                    signal, signal_lengths)  # (N,L,C), [32, 256, 6]

                log_probs = enc_output.transpose(1, 0).log_softmax(
                    dim=-1)  # (L,N,C), [256,32,6]
                assert signal.size(2) == 1
                target_lengths = label.ne(constants.PAD).sum(1)

                concat_label = torch.flatten(label)
                concat_label = concat_label[concat_label.lt(constants.PAD)]

                loss = F.ctc_loss(log_probs,
                                  concat_label,
                                  enc_output_lengths,
                                  target_lengths,
                                  blank=0,
                                  reduction='sum')

                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()

                optimizer.step_and_update_lr()
                total_loss.append(loss.item() / signal.size(0))
                if batch_step % args.show_steps == 0:
                    print(
                        '{gpu:d} training: epoch {epoch:d}, step {step:d}, loss {loss:.6f}, time: {t:.3f}'
                        .format(gpu=gpu,
                                epoch=epoch + 1,
                                step=batch_step,
                                loss=np.mean(total_loss),
                                t=(time.time() - start) / 60),
                        flush=True)
                    start = time.time()
            else:
                print(
                    '{gpu:d} training: epoch {epoch:d}, step {step:d}, loss {loss:.6f}, time: {t:.3f}'
                    .format(gpu=gpu,
                            epoch=epoch + 1,
                            step=batch_step,
                            loss=np.mean(total_loss),
                            t=(time.time() - start) / 60),
                    flush=True)
                break
        # valid
        start = time.time()
        model.eval()
        total_loss = []
        batch_step = 0
        with torch.no_grad():
            total_wer, total_cer, num_tokens, num_chars = 0, 0, 0, 0
            while True:
                batch = valid_provider.next()
                signal, label = batch
                if signal is not None and label is not None:
                    batch_step += 1
                    signal = signal.type(
                        torch.FloatTensor).cuda(non_blocking=True)
                    label = label.type(
                        torch.LongTensor).cuda(non_blocking=True)

                    signal_lengths = signal.squeeze(2).ne(
                        constants.SIG_PAD).sum(1)
                    enc_output, enc_output_lengths = model(
                        signal, signal_lengths)

                    log_probs = enc_output.transpose(1, 0).log_softmax(
                        2)  # (L,N,C)

                    assert signal.size(2) == 1
                    # input_lengths = signal.squeeze(2).ne(constants.SIG_PAD).sum(1)
                    target_lengths = label.ne(constants.PAD).sum(1)
                    concat_label = torch.flatten(label)
                    concat_label = concat_label[concat_label.lt(constants.PAD)]

                    loss = F.ctc_loss(log_probs,
                                      concat_label,
                                      enc_output_lengths,
                                      target_lengths,
                                      blank=0,
                                      reduction='sum')
                    total_loss.append(loss.item() / signal.size(0))
                    if batch_step % args.show_steps == 0:
                        print(
                            '{gpu:d} validate: epoch {epoch:d}, step {step:d}, loss {loss:.6f}, time: {t:.3f}'
                            .format(gpu=gpu,
                                    epoch=epoch + 1,
                                    step=batch_step,
                                    loss=np.mean(total_loss),
                                    t=(time.time() - start) / 60),
                            flush=True)
                        start = time.time()

                    log_probs = log_probs.transpose(1, 0)  # (N,L,C)
                    target_strings = target_decoder.convert_to_strings(
                        label, target_lengths)
                    decoded_output, _ = decoder.decode(log_probs,
                                                       enc_output_lengths)

                    for x in range(len(label)):
                        transcript, reference = decoded_output[x][
                            0], target_strings[x][0]
                        cer_inst = decoder.cer(transcript, reference)
                        total_cer += cer_inst
                        num_chars += len(reference)
                else:
                    break
            cer = float(total_cer) / num_chars
            list_charcter_error.append(cer)
            list_valid_loss.append(np.mean(total_loss))
            print(
                '{gpu:d} validate: epoch {epoch:d}, loss {loss:.6f}, character error {cer:.3f} time: {time:.3f}'
                .format(gpu=gpu,
                        epoch=epoch + 1,
                        loss=np.mean(total_loss),
                        cer=cer * 100,
                        time=(time.time() - start) / 60))
            start = time.time()
            # remember best_cer and save checkpoint
            if cer <= min(list_charcter_error) and np.mean(total_loss) <= min(
                    list_valid_loss) and args.store_model:
                if gpu == 0:
                    model_name = ('%s_e%d_loss%.2f_cer%.2f.chkpt') % (
                        args.save_model, epoch + 1, np.mean(total_loss),
                        cer * 100)
                    checkpoint = {
                        'model': model.state_dict(),
                        # 'optimizer': optimizer.state_dict(),
                        'settings': args,
                        'amp': amp.state_dict(),
                        'epoch': epoch + 1
                    }
                    torch.save(checkpoint, model_name)
                    print('    - [Info] The checkpoint file has been updated.',
                          flush=True)
Example #28
def main_worker(local_rank, nprocs, args):
    best_mae = 99.0

    dist.init_process_group(backend='nccl')

    model = SFCN()

    torch.cuda.set_device(local_rank)
    model.cuda()
    # When using a single GPU per process and per
    # DistributedDataParallel, we need to divide the batch size
    # ourselves based on the total number of GPUs we have
    args.batch_size = int(args.batch_size / nprocs)

    # define loss function (criterion) and optimizer
    criterion = loss_func

    #optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer = torch.optim.Adam(model.parameters(),
                             lr=args.lr,
                             weight_decay=args.weight_decay)
    model, optimizer = amp.initialize(model, optimizer)

    model = DistributedDataParallel(model)

    cudnn.benchmark = True

    # Data loading code
    train_data = CombinedData('', train=True)
    val_data = CombinedData('', train=False)

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)

    train_loader = DataLoader(train_data,
                              args.batch_size,
                              shuffle=False,
                              num_workers=4,
                              pin_memory=True,
                              sampler=train_sampler)
    val_loader = DataLoader(val_data,
                            args.batch_size,
                            shuffle=False,
                            num_workers=4,
                            pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, local_rank, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        train_sampler.set_epoch(epoch)

        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, local_rank,
              args)

        # evaluate on validation set

        mae = validate(val_loader, model, criterion, local_rank, args)

        # remember best acc@1 and save checkpoint
        is_best = mae < best_mae
        best_mae = min(mae, best_mae)

        if not os.path.exists("checkpoints/%s" % args.env_name):
            os.makedirs("checkpoints/%s" % args.env_name)

        if is_best:
            if local_rank == 0:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.module.state_dict(),
                        'best_mae': best_mae,
                        'amp': amp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }, True, './checkpoints/%s/%s_epoch_%s_%s' %
                    (args.env_name, args.env_name, epoch, best_mae))
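A hedged sketch of resuming from a checkpoint written by the loop above; `resume_path` is a placeholder, and `model`, `optimizer`, and `args` refer to the objects already built in `main_worker` (so `amp.initialize` has run before `amp.load_state_dict`).

resume_path = './checkpoints/my_run/checkpoint.pth'  # placeholder path
checkpoint = torch.load(resume_path, map_location='cpu')
model.module.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])
args.start_epoch = checkpoint['epoch']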
Example #29
def blend(config):
    """Run a panda training session."""
    # path to checkpoint models to blend
    if config.get('ckpt_dirs'):
        ckpt_dirs = config['ckpt_dirs']
    else:
        base_dir = os.path.join(config['input']['models'], config['experiment_name'])
        ckpt_dirs = [os.path.join(base_dir, fn) for fn in os.listdir(base_dir)]
    num_models = len(ckpt_dirs)

    # clean up the model directory and generate a new output path for the training session
    log_dir = os.path.join(config['output']['models'], config['experiment_name'])
    model_fname = f"blend_{num_models}_{datetime.datetime.now().strftime('%Y%m%d_%H%M')}"
    model_dir = os.path.join(log_dir, model_fname)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir, exist_ok=True)

    # log activity from the training session to a logfile
    #if config['local_rank'] == 0:
    #    sys.stdout = utils.Tee(os.path.join(model_dir, 'train_history.log'))
    utils.set_state(config['random_state'])
    device_ids = config.get('device_ids', [0])
    device = torch.device(f"cuda:{device_ids[0]}" if torch.cuda.is_available() else "cpu")

    with open(os.path.join(model_dir, 'config.json'), 'w') as f:
        json.dump(config, f)

    # cuda settings
    if config['distributed']:
        torch.cuda.set_device(config['local_rank'])
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='env://'
        )
    else:
        if isinstance(device_ids, list):
            visible_devices = ','.join([str(x) for x in device_ids])
        else:
            visible_devices = device_ids
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices

    fold_ids = config['fold_ids']
    fold_ids = fold_ids if isinstance(fold_ids, list) else [fold_ids]

    # build each dataset and model to blend
    for fold in fold_ids:
        fold = int(fold) if fold.isdigit() else fold
        models = []
        train_dls = []
        val_dls = []
        test_dls = []
        for ckpt_dir in ckpt_dirs:
            model, train_dl, val_dl, test_dl = get_branch(ckpt_dir=ckpt_dir,
                                                          root=config['input']['root'],
                                                          cv_folds_dir=config['input']['cv_folds'],
                                                          train=config['input']['train'],
                                                          fold=fold,
                                                          patient=config['input'].get('patient'),
                                                          blacklist=config['input'].get('blacklist'),
                                                          batch_size=config['batch_size'],
                                                          num_workers=config['num_workers'],
                                                          keep_prob=config.get('keep_prob', 1),
                                                          verbose=config.get('verbose', 1))
            models.append(model)
            train_dls.append(train_dl)
            val_dls.append(val_dl)
            test_dls.append(test_dl)

        model = BlendModel(models, **config['model'].get('params', {}))
        if config['input'].get('pretrained'):
            print(f"Loading pretrained weights from {config['input']['pretrained']}")
            model.load_state_dict(torch.load(config['input']['pretrained']))

        model = model.cuda()
        if config['distributed']:
            # convert BatchNorm layers to synchronized BatchNorm before
            # mixed-precision setup and DDP wrapping
            model = convert_syncbn_model(model)

        optim = train_utils.get_optimizer(model=model, **config['optimizer'])
        sched = train_utils.get_scheduler(config,
                                          optim,
                                          steps_per_epoch=len(train_dl))
        criterion = train_utils.get_criterion(config)
        postprocessor = train_utils.get_postprocessor(**config['postprocessor']) if config.get('postprocessor') else None

        # amp.initialize must be called before the model is wrapped in
        # DistributedDataParallel, so mixed precision is set up first.
        if config['distributed'] and config['fp_16']:
            model, optim = amp.initialize(model,
                                          optim,
                                          opt_level=config.get('opt_level', 'O2'),
                                          loss_scale='dynamic',
                                          keep_batchnorm_fp32=config.get('keep_batchnorm_fp32', True))

        if config['distributed']:
            model = DistributedDataParallel(model, delay_allreduce=True)
        else:
            model = nn.DataParallel(model, device_ids).to(device)

        ckpt_dir = os.path.join(model_dir, f'fold_{fold}')
        print(f'Checkpoint path {ckpt_dir}\n')
        os.makedirs(ckpt_dir, exist_ok=True)

        trainer = BlendTrainer(model,
                               optim,
                               criterion,
                               postprocessor=postprocessor,
                               scheduler=sched,
                               ckpt_dir=ckpt_dir,
                               fp_16=config['fp_16'],
                               rank=config.get('local_rank', 0),
                               **config.get('trainer', {}))
        trainer.fit(train_dls, config['steps'], val_dls)

        # generate predictions table from the best step model
        if config['local_rank'] == 0:
            trainer.load_model(step=trainer.best_step, models=models)
            df_pred_val = generate_df_pred(trainer,
                                           val_dls,
                                           val_dls[0].dataset.df,
                                           postprocessor=postprocessor,
                                           mode='blend')
            df_pred_val.to_csv(os.path.join(ckpt_dir, 'val_predictions.csv'), index=False)
            log_model_summary(df_pred_val, logger=trainer.logger, group='val')
            if postprocessor is not None:
                print('Updating postprocessor val predictions')
                postprocessor.fit(df_pred_val['prediction_raw'], df_pred_val[config['target_col']])

            if test_dl is not None:
                df_pred_test = generate_df_pred(trainer,
                                                test_dls,
                                                test_dls[0].dataset.df,
                                                postprocessor=postprocessor,
                                                mode='blend',
                                                num_bags=config.get('num_bags'))
                df_pred_test.to_csv(os.path.join(ckpt_dir, 'test_predictions.csv'), index=False)
                log_model_summary(df_pred_test, logger=trainer.logger, group='test')

            if postprocessor is not None:
                np.save(os.path.join(ckpt_dir, 'coef.npy'), postprocessor.get_coef())
            print(f'Saved output to {ckpt_dir}')
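blend() is driven entirely by its config dict, which it also re-serializes to config.json in the run directory. Below is a minimal driver sketch, assuming argparse/json are imported as in the surrounding examples; the config path and CLI flags are placeholders, and the keys inside the JSON file must match the ones blend() reads above.

# Hypothetical driver for blend(); config path and flags are placeholders.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='configs/blend.json')
    parser.add_argument('--local_rank', default=0, type=int)
    cli = parser.parse_args()
    with open(cli.config) as f:
        config = json.load(f)
    config['local_rank'] = cli.local_rank
    blend(config)
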
def main():
    parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
    arg = parser.add_argument
    arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
    arg('--workers', type=int, default=6, help='number of cpu threads to use')
    arg('--gpu',
        type=str,
        default='0',
        help='List of GPUs for parallel training, e.g. 0,1,2,3')
    arg('--output-dir', type=str, default='weights/')
    arg('--resume', type=str, default='')
    arg('--fold', type=int, default=0)
    arg('--prefix', type=str, default='classifier_')
    arg('--data-dir', type=str, default="/mnt/sota/datasets/deepfake")
    arg('--val-dir', type=str, default="../dfdc_train_all/dfdc_test")
    arg('--folds-csv', type=str, default='folds.csv')
    arg('--val-folds-csv', type=str)
    arg('--crops-dir', type=str, default='crops')
    arg('--label-smoothing', type=float, default=0.01)
    arg('--logdir', type=str, default='logs')
    arg('--zero-score', action='store_true', default=False)
    arg('--from-zero', action='store_true', default=False)
    arg('--distributed', action='store_true', default=False)
    arg('--freeze-epochs', type=int, default=0)
    arg("--local_rank", default=0, type=int)
    arg("--seed", default=777, type=int)
    arg("--padding-part", default=3, type=int)
    arg("--opt-level", default='O1', type=str)
    arg("--test_every", type=int, default=1)
    arg("--no-oversample", action="store_true")
    arg("--no-hardcore", action="store_true")
    arg("--only-changed-frames", action="store_true")

    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
    else:
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True

    conf = load_config(args.config)
    model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])

    model = model.cuda()
    if args.distributed:
        model = convert_syncbn_model(model)
    ohem = conf.get("ohem_samples", None)
    reduction = "mean"
    if ohem:
        reduction = "none"
    loss_fn = []
    weights = []
    for loss_name, weight in conf["losses"].items():
        loss_fn.append(losses.__dict__[loss_name](reduction=reduction).cuda())
        weights.append(weight)
    loss = WeightedLosses(loss_fn, weights)
    loss_functions = {"classifier_loss": loss}
    optimizer, scheduler = create_optimizer(conf['optimizer'], model)
    bce_best = 100
    start_epoch = 0
    batch_size = conf['optimizer']['batch_size']
    print("Config Loaded")
    data_train = DeepFakeClassifierDataset(
        mode="train",
        oversample_real=not args.no_oversample,
        fold=args.fold,
        padding_part=args.padding_part,
        hardcore=not args.no_hardcore,
        crops_dir=args.crops_dir,
        data_path=args.data_dir,
        label_smoothing=args.label_smoothing,
        folds_csv=args.folds_csv,
        transforms=create_train_transforms(conf["size"]),
        normalize=conf.get("normalize", None))
    print("train data Loaded")
    data_val = DeepFakeClassifierDataset(mode="val",
                                         fold=args.fold,
                                         padding_part=args.padding_part,
                                         crops_dir=args.crops_dir,
                                         data_path=args.data_dir,
                                         folds_csv=args.folds_csv,
                                         transforms=create_val_transforms(
                                             conf["size"]),
                                         normalize=conf.get("normalize", None))
    print("val data Loaded")
    val_data_loader = DataLoader(data_val,
                                 batch_size=batch_size * 2,
                                 num_workers=args.workers,
                                 shuffle=False,
                                 pin_memory=False)
    os.makedirs(args.logdir, exist_ok=True)
    summary_writer = SummaryWriter(args.logdir + '/' +
                                   conf.get("prefix", args.prefix) +
                                   conf['encoder'] + "_" + str(args.fold))
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            state_dict = checkpoint['state_dict']
            # strip the "module." prefix added by (Distributed)DataParallel wrappers
            state_dict = {k[7:]: w for k, w in state_dict.items()}
            model.load_state_dict(state_dict, strict=False)
            if not args.from_zero:
                start_epoch = checkpoint['epoch']
                if not args.zero_score:
                    bce_best = checkpoint.get('bce_best', 0)
            print("=> loaded checkpoint '{}' (epoch {}, bce_best {})".format(
                args.resume, checkpoint['epoch'], checkpoint['bce_best']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.from_zero:
        start_epoch = 0
    current_epoch = start_epoch

    if conf['fp16']:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')

    snapshot_name = "{}{}_{}_{}".format(conf.get("prefix",
                                                 args.prefix), conf['network'],
                                        conf['encoder'], args.fold)

    if args.distributed:
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model).cuda()
    data_val.reset(1, args.seed)
    max_epochs = conf['optimizer']['schedule']['epochs']
    for epoch in range(start_epoch, max_epochs):
        data_train.reset(epoch, args.seed)
        train_sampler = None
        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                data_train)
            train_sampler.set_epoch(epoch)
        if epoch < args.freeze_epochs:
            print("Freezing encoder!!!")
            model.module.encoder.eval()
            for p in model.module.encoder.parameters():
                p.requires_grad = False
        else:
            model.module.encoder.train()
            for p in model.module.encoder.parameters():
                p.requires_grad = True

        train_data_loader = DataLoader(data_train,
                                       batch_size=batch_size,
                                       num_workers=args.workers,
                                       shuffle=train_sampler is None,
                                       sampler=train_sampler,
                                       pin_memory=False,
                                       drop_last=True)

        train_epoch(current_epoch, loss_functions, model, optimizer, scheduler,
                    train_data_loader, summary_writer, conf, args.local_rank,
                    args.only_changed_frames)
        model = model.eval()

        if args.local_rank == 0:
            torch.save(
                {
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'bce_best': bce_best,
                }, args.output_dir + '/' + snapshot_name + "_last")
            torch.save(
                {
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'bce_best': bce_best,
                },
                os.path.join(args.output_dir, snapshot_name + "_{}".format(current_epoch)))
            if (epoch + 1) % args.test_every == 0:
                bce_best = evaluate_val(args,
                                        val_data_loader,
                                        bce_best,
                                        model,
                                        snapshot_name=snapshot_name,
                                        current_epoch=current_epoch,
                                        summary_writer=summary_writer)
        current_epoch += 1
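Both entry points in this section initialize the process group with init_method='env://' and read a --local_rank argument, so distributed runs are expected to go through the PyTorch distributed launcher. A launch sketch with assumed script and config names follows.

# Assumed invocation; the script name, GPU count, and config path are placeholders.
#
#   python -m torch.distributed.launch --nproc_per_node=4 \
#       train_classifier.py --config configs/effnet_b7.json --distributed
#
# Single-process runs omit --distributed and select devices via --gpu instead.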