Example #1
0
def custom_trainer(model, device, RE_train_dataset, RE_valid_dataset, args):
    """Train `model` on the RE datasets with AdamW, printing train/valid accuracy.

    Args:
        model: HuggingFace-style model that returns ``(loss, ...)`` when
            ``labels`` is passed.
        device: torch device the batch tensors are moved to.
        RE_train_dataset / RE_valid_dataset: map-style datasets yielding dicts
            with ``input_ids``, ``attention_mask`` and ``labels``.
        args: dict with keys ``train_batch_size``, ``eval_batch_size``,
            ``num_workers``, ``lr``, ``epochs``.
    """
    train_loader = DataLoader(RE_train_dataset,
                              batch_size=args['train_batch_size'],
                              shuffle=True,
                              num_workers=args['num_workers'])
    # NOTE(review): shuffling the validation loader is unnecessary but
    # harmless for accuracy; kept to preserve existing behavior.
    valid_loader = DataLoader(RE_valid_dataset,
                              batch_size=args['eval_batch_size'],
                              shuffle=True,
                              num_workers=args['num_workers'])

    optim = AdamW(model.parameters(), lr=args['lr'])
    # Removed unused local `loss_fn = LabelSmoothingLoss()` -- the training
    # loss below comes from the model's own output (outputs[0]).

    EPOCHS, print_every = args['epochs'], 1

    for epoch in range(EPOCHS):
        # Re-enable train mode each epoch: func_eval below presumably switches
        # the model into eval mode -- TODO confirm against func_eval.
        model.train()
        loss_val_sum = 0.0

        for batch in tqdm(train_loader):
            optim.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids,
                            attention_mask=attention_mask,
                            labels=labels)

            loss = outputs[0]

            loss.backward()
            optim.step()
            # BUGFIX: accumulate a Python float, not the loss tensor.
            # Summing tensors kept every batch's autograd graph alive,
            # steadily growing (GPU) memory over the epoch.
            loss_val_sum += loss.item()
        loss_val_avg = loss_val_sum / len(train_loader)

        if ((epoch % print_every) == 0 or epoch == (EPOCHS - 1)):

            train_accr = func_eval(model, train_loader, device)
            valid_accr = func_eval(model, valid_loader, device)
            print(
                "epoch:[%d] loss:[%.3f] train_accr:[%.3f] valid_accr:[%.3f]." %
                (epoch, loss_val_avg, train_accr, valid_accr))
Example #2
0
    def __init__(self,
                 encoder,
                 decoder,
                 optimizer_params=None,
                 amp_params=None,
                 n_jobs=0,
                 rank=0):
        """Build the trainer: criteria, grouped RAdam optimizer, apex AMP/DDP.

        Args:
            encoder: encoder module; moved to this rank's GPU and wrapped in
                apex ``DistributedDataParallel``.
            decoder: only ``decoder.num_classes`` is read here (the decoder
                itself is currently commented out).
            optimizer_params: optional dict with ``lr``, ``weight_decay``,
                ``warmap``, ``amsgrad``.
            amp_params: optional dict with ``opt_level``, ``loss_scale``.
            n_jobs: stored worker count (used elsewhere).
            rank: local GPU index; selects ``cuda:<rank>``.
        """
        # BUGFIX: mutable default arguments ({}) are shared across all calls;
        # use None sentinels and create fresh dicts per call instead.
        optimizer_params = {} if optimizer_params is None else optimizer_params
        amp_params = {} if amp_params is None else amp_params

        lr = optimizer_params.get('lr', 1e-3)
        weight_decay = optimizer_params.get('weight_decay', 0)
        # NOTE(review): 'warmap' (sic) and 'amsgrad' are read but not used in
        # this method -- confirm whether they are consumed elsewhere.
        warmap = optimizer_params.get('warmap', 100)
        amsgrad = optimizer_params.get('amsgrad', False)
        opt_level = amp_params.get('opt_level', 'O0')
        loss_scale = amp_params.get('loss_scale', None)

        self.device = torch.device('cuda:' + str(rank))
        self.encoder = encoder.to(self.device)
        #self.decoder = decoder.to(self.device)
        self.num_classes = decoder.num_classes
        # NOTE(review): despite its name, mse_critetion (sic) is an L1 loss.
        self.mse_critetion = nn.L1Loss()
        self.ce_criterion = LabelSmoothingLoss(self.num_classes,
                                               smoothing=0.1,
                                               reduction='none').to(
                                                   self.device)
        self.vat_criterion = VATLoss()
        self.cutmix = CutMix(self.num_classes)

        param_optimizer = list(self.encoder.named_parameters()
                               )  #+ list(self.decoder.named_parameters())
        # Split parameters so that batch-norm and bias terms get no weight
        # decay, everything else gets the configured decay.
        no_decay = ['bn', 'bias']
        optimizer_grouped_parameters = [
            {
                'params': [
                    p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay': weight_decay,
            },
            {
                'params': [
                    p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay': 0.0,
            },
        ]

        # NOTE(review): the per-group 'weight_decay' values above override this
        # optimizer-level default, so passing weight_decay here is redundant
        # (but harmless).
        self.optimizer = RAdam(optimizer_grouped_parameters,
                               lr=lr,
                               weight_decay=weight_decay)

        self.is_master = torch.distributed.get_rank() == 0
        torch.cuda.set_device(rank)
        # apex AMP rewrites the model list and optimizer in place.
        [self.encoder
         ], self.optimizer = apex.amp.initialize([self.encoder],
                                                 self.optimizer,
                                                 opt_level=opt_level,
                                                 loss_scale=loss_scale,
                                                 verbosity=1)

        self.scheduler = StepLR(self.optimizer, step_size=20, gamma=0.5)

        self.encoder = apex.parallel.DistributedDataParallel(
            self.encoder, delay_allreduce=True)
        #self.decoder = apex.parallel.DistributedDataParallel(self.decoder, delay_allreduce=True)

        self.last_epoch = 0
        self.n_jobs = n_jobs
Example #3
0
# Training list: one ("i<idx>", "t<idx>") key pair per training batch --
# presumably input/target dataset keys in an h5 file; confirm against loader.
tl = [("i" + str(i), "t" + str(i)) for i in range(ntrain)]

if fine_tune_m is None:
    # Fresh run: initialize parameters, then apply the project's fix-up hook.
    mymodel = init_model_params(mymodel)
    mymodel.apply(init_fixing)
else:
    # Fine-tuning: restore weights from the given checkpoint (CPU-side load).
    logger.info("Load pre-trained model from: " + fine_tune_m)
    mymodel = load_model_cpu(fine_tune_m, mymodel)

#lw = torch.ones(nwordt).float()
#lw[0] = 0.0
#lossf = nn.NLLLoss(lw, ignore_index=0, reduction='sum')
# Label-smoothed loss over nwordt classes; index 0 is ignored (presumably
# padding) and cnfg.forbidden_indexes are passed as forbidden_index.
lossf = LabelSmoothingLoss(nwordt,
                           cnfg.label_smoothing,
                           ignore_index=0,
                           reduction='sum',
                           forbidden_index=cnfg.forbidden_indexes)

if cnfg.src_emb is not None:
    # Optionally load a pre-trained source embedding table.
    logger.info("Load source embedding from: " + cnfg.src_emb)
    _emb = torch.load(cnfg.src_emb, map_location='cpu')
    if nwordi < _emb.size(0):
        # Truncate the table to the source vocabulary size.
        _emb = _emb.narrow(0, 0, nwordi).contiguous()
    if cnfg.scale_down_emb:
        # Scale down by sqrt(cnfg.isize).
        _emb.div_(sqrt(cnfg.isize))
    mymodel.enc.wemb.weight.data = _emb.data
    # Optionally freeze the loaded source embeddings.
    if cnfg.freeze_srcemb:
        mymodel.enc.wemb.weight.requires_grad_(False)
    else:
        mymodel.enc.wemb.weight.requires_grad_(True)
Example #4
0
#encoding: utf-8

# Smoke test for LabelSmoothingLoss: builds a toy (5, 1) target and a random
# (5, 8) logit tensor, prints the per-element loss and the gradients.

import torch
import torch.nn.functional as F
from loss import LabelSmoothingLoss

# 8 classes, smoothing 0.1; class 0 is ignored, class 3 is forbidden.
lossf = LabelSmoothingLoss(8,
                           label_smoothing=0.1,
                           ignore_index=0,
                           reduction='none',
                           forbidden_index=3)
# Target column: row 0 hits the ignore_index, the rest are ordinary classes.
# Direct indexing replaces the deprecated `.data[...]` writes -- `target`
# carries no grad, so plain in-place assignment is equivalent and idiomatic.
target = torch.ones(5, 1).long()
target[0] = 0
target[1] = 1
target[2] = 2
target[3] = 4
td = torch.randn(5, 8)
#td.narrow(1, 3, 1).fill_(-1e32)
td.requires_grad_(True)
print(td)
output = F.log_softmax(td, -1)
print(output)
cost = lossf(output, target)
print(cost)
cost.sum().backward()
# td.grad flows through log_softmax; output.grad is only populated if
# output retains grad (non-leaf), hence may print None.
print(output.grad)
print(td.grad)
Example #5
0
                    Config.num_head, Config.head_size, Config.feedforward_size,
                    Config.dropout, Config.attn_dropout, Config.layer_norm_eps)
model.cuda()

# load model if pretrained
if Config.pretrained:
    model.load_state_dict(torch.load(Config.model_load))

# optimizer: Adam wrapped in NoamOpt -- presumably the "Noam" warmup /
# inverse-sqrt schedule from the Transformer paper; confirm in NoamOpt.
optimizer = NoamOpt(
    Config.hidden_size, Config.factor, Config.warmup,
    Adam(model.parameters(), lr=Config.lr, betas=(0.9, 0.98), eps=1e-9))

# criterion: label smoothing 0.1 over the output vocabulary; padding ignored
criterion = LabelSmoothingLoss(0.1,
                               tgt_vocab_size=output_lang.n_words,
                               ignore_index=Config.PADDING_token).cuda()
#criterion = nn.NLLLoss()

# make sure the checkpoint directory exists
if not os.path.exists('ckpts'):
    os.makedirs('ckpts')

# training
# NOTE(review): best_bleu is initialized but never updated in this view --
# confirm checkpoint-on-best logic exists elsewhere.
best_bleu = -1
for i in range(Config.n_epoch):
    train_one_epoch(train_loader, model, optimizer, criterion, print_every=50)
    if i % 5 == 0:
        # Qualitative check: decode a few random dev pairs every 5 epochs.
        evaluateRandomly(pairs_dev, input_lang, output_lang, model, n=3)
    acc, bleu = evaluate_dataset(dev_loader, model, output_lang)
    print("accuracy: {}  bleu score: {}".format(acc, bleu))
Example #6
0
def main():
    '''Entry point: read config, set up devices and seeds, build the NMT
    model and label-smoothing loss, then run the training/validation loop
    with checkpointing, best-model saving and early stopping.'''
    rid = cnfg.runid  # Get run ID from cnfg file where training files will be stored
    if len(sys.argv) > 1:
        rid = sys.argv[1]  # getting runid from console

    earlystop = cnfg.earlystop  # epochs without improvement before stopping
    epochs = cnfg.epochs  # total number of training epochs

    tokens_optm = cnfg.tokens_optm  # number of tokens per optimizer step

    done_tokens = tokens_optm

    batch_report = cnfg.batch_report
    report_eva = cnfg.report_eva

    use_cuda = cnfg.use_cuda
    gpuid = cnfg.gpuid

    # GPU configuration.  gpuid is expected in the form "cuda:0,1,2":
    # the part before the first "," names the primary device, the integers
    # after ":" enumerate all devices -- TODO confirm the expected format.
    if use_cuda and torch.cuda.is_available():
        use_cuda = True
        if len(gpuid.split(",")) > 1:
            cuda_device = torch.device(gpuid[:gpuid.find(",")].strip())
            cuda_devices = [
                int(_.strip()) for _ in gpuid[gpuid.find(":") + 1:].split(",")
            ]
            print('[Info] using multiple gpu', cuda_devices)
            multi_gpu = True
        else:
            cuda_device = torch.device(gpuid)
            multi_gpu = False
            print('[Info] using single gpu', cuda_device)
            cuda_devices = None
        torch.cuda.set_device(cuda_device.index)
    else:
        cuda_device = False
        print('using single cpu')
        multi_gpu = False
        cuda_devices = None

    use_ams = cnfg.use_ams  # whether Adam uses AMSGrad (passed below)

    save_optm_state = cnfg.save_optm_state

    save_every = cnfg.save_every

    epoch_save = cnfg.epoch_save

    remain_steps = cnfg.training_steps

    wkdir = "".join((cnfg.work_dir, cnfg.data_dir, "/", rid,
                     "/"))  # CREATING MODEL DIRECTORY
    if not path_check(wkdir):
        makedirs(wkdir)

    # Checkpoint paths are only set when periodic saving is enabled.
    chkpt = None
    chkptoptf = None
    chkptstatesf = None
    if save_every is not None:
        chkpt = wkdir + "checkpoint.t7"
        if save_optm_state:
            chkptoptf = wkdir + "checkpoint.optm.t7"
            chkptstatesf = wkdir + "checkpoint.states"

    logger = get_logger(wkdir + "train.log")  # Logger object

    train_data = h5py.File(cnfg.train_data,
                           "r")  # training data read from h5 file
    valid_data = h5py.File(cnfg.dev_data,
                           "r")  # validation data read from h5 file

    print('[Info] Training and Validation data are loaded.')

    ntrain = int(
        train_data["ndata"][:][0])  # number of batches for TRAINING DATA
    nvalid = int(
        valid_data["ndata"][:][0])  # number of batches for VALIDATION DATA
    nwordi = int(train_data["nwordi"][:][0])  # VOCAB SIZE FOR SOURCE
    nwordt = int(
        train_data["nwordt"][:][0])  # VOCAB SIZE FOR PE [TODO: SIMILAR FOR MT]

    print('[INFO] number of batches for TRAINING DATA: ', ntrain)
    print('[INFO] number of batches for VALIDATION DATA: ', nvalid)
    print('[INFO] Source vocab size: ', nwordi)
    print('[INFO] Target vocab size: ', nwordt)

    # Seed everything for reproducibility (falls back to torch's seed).
    random_seed = torch.initial_seed() if cnfg.seed is None else cnfg.seed

    rpyseed(random_seed)

    if use_cuda:
        torch.cuda.manual_seed_all(random_seed)
        print('[Info] Setting up random seed using CUDA.')
    else:
        torch.manual_seed(random_seed)

    logger.info("Design models with seed: %d" % torch.initial_seed())

    mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.num_src_layer,
                  cnfg.num_mt_layer, cnfg.num_pe_layer, cnfg.ff_hsize,
                  cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead,
                  cnfg.cache_len, cnfg.attn_hsize, cnfg.norm_output,
                  cnfg.bindDecoderEmb,
                  cnfg.forbidden_indexes)  # TODO NEED DOCUMENTATION

    # One ("i<k>", "m<k>", "t<k>") h5-key triple per training batch.
    tl = [("i" + str(i), "m" + str(i), "t" + str(i))
          for i in range(ntrain)]  # TRAINING LIST

    fine_tune_m = cnfg.fine_tune_m
    # Fine tune model

    if fine_tune_m is not None:
        logger.info("Load pre-trained model from: " + fine_tune_m)
        mymodel = load_model_cpu(fine_tune_m, mymodel)

    # Label-smoothed loss; index 0 ignored (presumably padding).
    lossf = LabelSmoothingLoss(nwordt,
                               cnfg.label_smoothing,
                               ignore_index=0,
                               reduction='sum',
                               forbidden_index=cnfg.forbidden_indexes)

    if use_cuda:
        mymodel.to(cuda_device)
        lossf.to(cuda_device)

    if fine_tune_m is None:
        # Fresh run: xavier-init matrix parameters, optionally load
        # pre-trained source/target embedding tables, then apply fix-ups.
        for p in mymodel.parameters():
            if p.requires_grad and (p.dim() > 1):
                xavier_uniform_(p)
        if cnfg.src_emb is not None:
            _emb = torch.load(cnfg.src_emb, map_location='cpu')
            if nwordi < _emb.size(0):
                _emb = _emb.narrow(0, 0, nwordi).contiguous()
            if use_cuda:
                _emb = _emb.to(cuda_device)
            mymodel.enc.wemb.weight.data = _emb
            if cnfg.freeze_srcemb:
                mymodel.enc.wemb.weight.requires_grad_(False)
            else:
                mymodel.enc.wemb.weight.requires_grad_(True)
        if cnfg.tgt_emb is not None:
            _emb = torch.load(cnfg.tgt_emb, map_location='cpu')
            if nwordt < _emb.size(0):
                _emb = _emb.narrow(0, 0, nwordt).contiguous()
            if use_cuda:
                _emb = _emb.to(cuda_device)
            mymodel.dec.wemb.weight.data = _emb
            if cnfg.freeze_tgtemb:
                mymodel.dec.wemb.weight.requires_grad_(False)
            else:
                mymodel.dec.wemb.weight.requires_grad_(True)
        mymodel.apply(init_fixing)

    # lr will be over written by GoogleLR before used
    optimizer = optim.Adam(mymodel.parameters(),
                           lr=1e-4,
                           betas=(0.9, 0.98),
                           eps=1e-9,
                           weight_decay=cnfg.weight_decay,
                           amsgrad=use_ams)

    # TODO: Need to implement
    '''if multi_gpu:
        # mymodel = nn.DataParallel(mymodel, device_ids=cuda_devices, output_device=cuda_device.index)
        mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True,
                                 gather_output=False)
        lossf = DataParallelCriterion(lossf, device_ids=cuda_devices, output_device=cuda_device.index,
                                      replicate_once=True)'''

    # Load fine tune state if declared
    fine_tune_state = cnfg.fine_tune_state
    if fine_tune_state is not None:
        logger.info("Load optimizer state from: " + fine_tune_state)
        optimizer.load_state_dict(torch.load(fine_tune_state))

    lrsch = GoogleLR(optimizer, cnfg.isize, cnfg.warm_step)
    lrsch.step()

    num_checkpoint = cnfg.num_checkpoint
    cur_checkid = 0  # initialized current check point

    tminerr = float("inf")  # minimum error during training

    # Baseline validation before any training step.
    minloss, minerr = eva(valid_data, nvalid, mymodel, lossf, cuda_device,
                          multi_gpu)
    logger.info("".join(("Init lr: ", ",".join(tostr(getlr(optimizer))),
                         ", Dev Loss/Error: %.3f %.2f" % (minloss, minerr))))

    # if fine_tune_m is None:
    # NOTE(review): the initial model is saved unconditionally here AND again
    # inside the `if fine_tune_m is None:` branch just below, so fresh runs
    # save init.t7 twice -- likely one of the two saves should be removed.
    save_model(mymodel, wkdir + "init.t7", multi_gpu)
    logger.info("Initial model saved")
    # ==================================================Fine tune ========================================
    if fine_tune_m is None:
        save_model(mymodel, wkdir + "init.t7", multi_gpu)
        logger.info("Initial model saved")
    else:
        cnt_states = cnfg.train_statesf
        if (cnt_states is not None) and path_check(cnt_states):
            # Resume a partially trained epoch from the saved training states.
            logger.info("Continue last epoch")
            args = {}
            tminerr, done_tokens, cur_checkid, remain_steps, _ = train(
                train_data, load_states(cnt_states), valid_data, nvalid,
                optimizer, lrsch, mymodel, lossf, cuda_device, logger,
                done_tokens, multi_gpu, tokens_optm, batch_report, save_every,
                chkpt, chkptoptf, chkptstatesf, num_checkpoint, cur_checkid,
                report_eva, remain_steps, False)
            vloss, vprec = eva(valid_data, nvalid, mymodel, lossf, cuda_device,
                               multi_gpu)
            logger.info(
                "Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" %
                (tminerr, vloss, vprec))
            save_model(
                mymodel,
                wkdir + "train_0_%.3f_%.3f_%.2f.t7" % (tminerr, vloss, vprec),
                multi_gpu)
            if save_optm_state:
                torch.save(
                    optimizer.state_dict(), wkdir +
                    "train_0_%.3f_%.3f_%.2f.optm.t7" % (tminerr, vloss, vprec))
            logger.info("New best model saved")

        # assume that the continue trained model has already been through sort grad, thus shuffle the training list.
        shuffle(tl)
    # ====================================================================================================

    # ================================Dynamic sentence Sampling =========================================
    # dss_ws/dss_rm are batch counts derived from the configured fractions;
    # sampling itself is currently commented out at the bottom of the loop.
    if cnfg.dss_ws is not None and cnfg.dss_ws > 0.0 and cnfg.dss_ws < 1.0:
        dss_ws = int(cnfg.dss_ws * ntrain)
        _Dws = {}
        _prev_Dws = {}
        _crit_inc = {}
        if cnfg.dss_rm is not None and cnfg.dss_rm > 0.0 and cnfg.dss_rm < 1.0:
            dss_rm = int(cnfg.dss_rm * ntrain * (1.0 - cnfg.dss_ws))
        else:
            dss_rm = 0
    else:
        dss_ws = 0
        dss_rm = 0
        _Dws = None
    # ====================================================================================================

    namin = 0  # consecutive epochs without a new best (early-stop counter)

    # TRAINING EPOCH STARTS
    for i in range(1, epochs + 1):
        terr, done_tokens, cur_checkid, remain_steps, _Dws = train(
            train_data, tl, valid_data, nvalid, optimizer, lrsch, mymodel,
            lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm,
            batch_report, save_every, chkpt, chkptoptf, chkptstatesf,
            num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0)
        # VALIDATION
        vloss, vprec = eva(valid_data, nvalid, mymodel, lossf, cuda_device,
                           multi_gpu)
        logger.info(
            "Epoch: %d ||| train loss: %.3f ||| valid loss/error: %.3f/%.2f" %
            (i, terr, vloss, vprec))

        # CONDITION TO SAVE MODELS: new best on either valid loss or error.
        if (vprec <= minerr) or (vloss <= minloss):
            save_model(
                mymodel,
                wkdir + "eva_%d_%.3f_%.3f_%.2f.t7" % (i, terr, vloss, vprec),
                multi_gpu)
            if save_optm_state:
                torch.save(
                    optimizer.state_dict(), wkdir +
                    "eva_%d_%.3f_%.3f_%.2f.optm.t7" % (i, terr, vloss, vprec))
            logger.info("New best model saved"
                        )  # [TODO CALCULATE BLEU FOR VALIDATION SET]

            namin = 0

            if vprec < minerr:
                minerr = vprec
            if vloss < minloss:
                minloss = vloss

        else:
            # No new best on dev: optionally save on new best train error or
            # per-epoch, and advance the early-stop counter.
            if terr < tminerr:
                tminerr = terr
                save_model(
                    mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.t7" %
                    (i, terr, vloss, vprec), multi_gpu)
                if save_optm_state:
                    torch.save(
                        optimizer.state_dict(),
                        wkdir + "train_%d_%.3f_%.3f_%.2f.optm.t7" %
                        (i, terr, vloss, vprec))
            elif epoch_save:
                save_model(
                    mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.t7" %
                    (i, terr, vloss, vprec), multi_gpu)

            namin += 1
            # CONDITIONED TO EARLY STOP: flush any pending accumulated
            # gradients before leaving the loop.
            if namin >= earlystop:
                if done_tokens > 0:
                    if multi_gpu:
                        mymodel.collect_gradients()
                    optimizer.step()
                    # lrsch.step()
                    done_tokens = 0
                # optimizer.zero_grad()
                logger.info("early stop")
                break

        if remain_steps is not None and remain_steps <= 0:
            logger.info("Last training step reached")
            break
        '''if dss_ws > 0:
            if _prev_Dws:
                for _key, _value in _Dws.items():
                    if _key in _prev_Dws:
                        _ploss = _prev_Dws[_key]
                        _crit_inc[_key] = (_ploss - _value) / _ploss
                tl = dynamic_sample(_crit_inc, dss_ws, dss_rm)
            _prev_Dws = _Dws'''

        shuffle(tl)
        '''oldlr = getlr(optimizer)
        lrsch.step(terr)
        newlr = getlr(optimizer)
        if updated_lr(oldlr, newlr):
          logger.info("".join(("lr update from: ", ",".join(tostr(oldlr)), ", to: ", ",".join(tostr(newlr)))))
          hook_lr_update(optimizer, use_ams)'''

    # Flush any gradients still pending after the final epoch.
    if done_tokens > 0:
        if multi_gpu:
            mymodel.collect_gradients()
        optimizer.step()
    # lrsch.step()
    # done_tokens = 0
    # optimizer.zero_grad()

    save_model(mymodel, wkdir + "last.t7", multi_gpu)
    if save_optm_state:
        torch.save(optimizer.state_dict(), wkdir + "last.optm.t7")
    logger.info("model saved")

    train_data.close()
    valid_data.close()
Example #7
0
                                        lr=args.lr,
                                        momentum=0.9)
        # Optimizer selection (the opening `if` branch -- presumably SGD,
        # given lr/momentum above -- starts before this view).
        elif args.optim == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        elif args.optim == 'adamp':
            optimizer = AdamP(model.parameters(),
                              lr=args.lr,
                              betas=(0.9, 0.999),
                              weight_decay=1e-2)
        else:
            raise NameError('Not a optimizer available.')

        # Loss selection from the command-line flag.
        if args.loss == 'cross_entropy':
            criterion = nn.CrossEntropyLoss()
        elif args.loss == 'f1':
            criterion = F1Loss()
        elif args.loss == 'focal':
            criterion = FocalLoss()
        elif args.loss == 'label_smoothing':
            criterion = LabelSmoothingLoss(smoothing=args.smoothing)
        else:
            raise NameError('Not a loss function available.')

        # NOTE(review): seeding after model/optimizer construction only fixes
        # randomness from this point on (e.g. data order) -- confirm intended.
        seed_everything(args.seed)
        training(model, optimizer, criterion, trainloader, validloader,
                 args.epochs, args.model_name)

    else:

        test(model, args.model_name, str(args.chpkt_idx))
Example #8
0
def train_model_on_dataset(rank, cfg):
    """Per-process training entry point under NCCL DistributedDataParallel.

    Trains the retrieval model with a focal loss on the localization output
    plus label-smoothed auxiliary color/type heads, saving checkpoints from
    rank 0 every epoch.

    Args:
        rank: this process's GPU index (world_size == cfg.num_gpu).
        cfg: config with TRAIN.* hyper-parameters, num_gpu and resume_epoch.
    """
    dist_rank = rank
    # print(dist_rank)
    dist.init_process_group(backend="nccl", rank=dist_rank,
                            world_size=cfg.num_gpu,
                            init_method="env://")
    torch.cuda.set_device(rank)
    cudnn.benchmark = True
    dataset = CityFlowNLDataset(cfg, build_transforms(cfg))

    # SyncBatchNorm so BN statistics are shared across processes.
    model = MyModel(cfg, len(dataset.nl), dataset.nl.word_to_idx['<PAD>'], norm_layer=nn.SyncBatchNorm, num_colors=len(CityFlowNLDataset.colors), num_types=len(CityFlowNLDataset.vehicle_type) - 2).cuda()
    model = DistributedDataParallel(model, device_ids=[rank],
                                    output_device=rank,
                                    broadcast_buffers=cfg.num_gpu > 1, find_unused_parameters=False)
    optimizer = torch.optim.Adam(
            params=model.parameters(),
            lr=cfg.TRAIN.LR.BASE_LR, weight_decay=0.00003)
    lr_scheduler = WarmupMultiStepLR(optimizer,
                            milestones=cfg.TRAIN.STEPS,
                            gamma=cfg.TRAIN.LR.WEIGHT_DECAY,
                            warmup_factor=cfg.TRAIN.WARMUP_FACTOR,
                            warmup_iters=cfg.TRAIN.WARMUP_EPOCH)
    # Label-smoothed (0.1) criteria for the auxiliary classification heads.
    color_loss = LabelSmoothingLoss(len(dataset.colors), 0.1)
    vehicle_loss = LabelSmoothingLoss(len(dataset.vehicle_type) - 2, 0.1)
    if cfg.resume_epoch > 0:
        # Resume model + optimizer and realign the LR schedule.
        model.load_state_dict(torch.load(f'save/{cfg.resume_epoch}.pth'))
        optimizer.load_state_dict(torch.load(f'save/{cfg.resume_epoch}_optim.pth'))
        lr_scheduler.last_epoch = cfg.resume_epoch
        lr_scheduler.step()
        if rank == 0:
            print(f'resume from {cfg.resume_epoch} pth file, starting {cfg.resume_epoch+1} epoch')
        cfg.resume_epoch += 1

    # loader = DataLoader(dataset, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=True, num_workers=cfg.TRAIN.NUM_WORKERS)
    # Batch size / workers are divided per GPU; the sampler handles shuffling.
    train_sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=cfg.TRAIN.BATCH_SIZE //cfg.num_gpu,
                            num_workers=cfg.TRAIN.NUM_WORKERS // cfg.num_gpu,# shuffle=True,
                            sampler=train_sampler, pin_memory=True)
    for epoch in range(cfg.resume_epoch, cfg.TRAIN.EPOCH):
        # Running sums of the per-batch losses for reporting averages.
        losses = 0.
        losses_color = 0.
        losses_types = 0.
        losses_nl_color = 0.
        losses_nl_types = 0.
        precs = 0.
        train_sampler.set_epoch(epoch)
        for idx, (nl, frame, label, act_map, color_label, type_label, nl_color_label, nl_type_label) in enumerate(loader):
            # print(nl.shape)
            # print(global_img.shape)
            # print(local_img.shape)
            nl = nl.cuda(non_blocking=True)
            label = label.cuda(non_blocking=True)
            act_map = act_map.cuda(non_blocking=True)
            # global_img, local_img = global_img.cuda(), local_img.cuda()
            # Put the sequence dimension first -- presumably (seq, batch) for
            # the text encoder; confirm against MyModel.
            nl = nl.transpose(1, 0)
            frame = frame.cuda(non_blocking=True)
            color_label = color_label.cuda(non_blocking=True)
            type_label = type_label.cuda(non_blocking=True)
            nl_color_label = nl_color_label.cuda(non_blocking=True)
            nl_type_label = nl_type_label.cuda(non_blocking=True)
            output, color, types, nl_color, nl_types = model(nl, frame, act_map)
            
            # loss = sampling_loss(output, label, ratio=5)
            # loss = F.binary_cross_entropy_with_logits(output, label)
            # Normalize the focal loss by the global positive count averaged
            # over GPUs (reduce_sum aggregates across processes).
            total_num_pos = reduce_sum(label.new_tensor([label.sum()])).item()
            num_pos_avg_per_gpu = max(total_num_pos / float(cfg.num_gpu), 1.0)

            loss = sigmoid_focal_loss(output, label, reduction='sum') / num_pos_avg_per_gpu
            # Auxiliary heads, each weighted by its cfg.TRAIN.ALPHA_* factor.
            loss_color = color_loss(color, color_label) * cfg.TRAIN.ALPHA_COLOR
            loss_type = vehicle_loss(types, type_label) * cfg.TRAIN.ALPHA_TYPE
            loss_nl_color = color_loss(nl_color, nl_color_label) * cfg.TRAIN.ALPHA_NL_COLOR
            loss_nl_type = vehicle_loss(nl_types, nl_type_label) * cfg.TRAIN.ALPHA_NL_TYPE
            loss_total = loss + loss_color + loss_type + loss_nl_color + loss_nl_type
            optimizer.zero_grad()
            loss_total.backward()
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            
            losses += loss.item()
            losses_color += loss_color.item()
            losses_types += loss_type.item()
            losses_nl_color += loss_nl_color.item()
            losses_nl_types += loss_nl_type.item()
            # precs += recall.item()
            
            if rank == 0 and idx % cfg.TRAIN.PRINT_FREQ == 0:
                # recall = correctly-predicted positives / all positives;
                # ca/ta are the color/type head accuracies on this batch.
                pred = (output.sigmoid() > 0.5)
                # print((pred == label).sum())
                pred = (pred == label) 
                recall = (pred * label).sum() / label.sum()
                ca = (color.argmax(dim=1) == color_label)
                ca = ca.sum().item() / ca.numel()
                ta = (types.argmax(dim=1) == type_label)
                ta = ta.sum().item() / ta.numel()
                # accu = pred.sum().item() / pred.numel()
                lr = optimizer.param_groups[0]['lr']
                print(f'epoch: {epoch},', 
                f'lr: {lr}, step: {idx}/{len(loader)},',
                f'loss: {losses / (idx + 1):.4f},', 
                f'loss color: {losses_color / (idx + 1):.4f},',
                f'loss type: {losses_types / (idx + 1):.4f},',
                f'loss nl color: {losses_nl_color / (idx + 1):.4f},',
                f'loss nl type: {losses_nl_types / (idx + 1):.4f},',
                f'recall: {recall.item():.4f}, c_accu: {ca:.4f}, t_accu: {ta:.4f}')
        lr_scheduler.step()
        # Only rank 0 writes checkpoints (model + optimizer per epoch).
        if rank == 0:
            if not os.path.exists('save'):
                os.mkdir('save')
            torch.save(model.state_dict(), f'save/{epoch}.pth')
            torch.save(optimizer.state_dict(), f'save/{epoch}_optim.pth')