def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--corpus', type=Path, required=True)
    parser.add_argument('--train', type=Path, required=True)
    parser.add_argument('--dev', type=Path, required=True)
    parser.add_argument('--model-name', type=str, required=True)

    parser.add_argument('--ckpt', type=str, default='ckpt')
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('-a', '--accumulation_steps', type=int, default=1)

    parser.add_argument('--fp16', action='store_true')
    # Automatically supplied by torch.distributed.launch
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    # Based on the SciFact transformers training code
    class SciFactLabelPredictionDataset(Dataset):
        def __init__(self, corpus, claims_file):

            claims, rationales, labels = self._read(corpus, claims_file)

            self._claims = claims
            self._rationales = rationales
            self._labels = labels

        def _read(self, corpus, claims_file):
            claims = []
            rationales = []
            labels = []

            corpus = {doc['doc_id']: doc for doc in jsonlines.open(corpus)}
            #label_encodings = {'CONTRADICT': 0, 'NOT_ENOUGH_INFO': 1, 'SUPPORT': 2} #From SciFact
            label_encodings = {
                'CONTRADICT': 1,
                'NOT_ENOUGH_INFO': 2,
                'SUPPORT': 0
            }  # To Match COVIDLies

            for claim in jsonlines.open(claims_file):

                if claim['evidence']:
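                    # For a claim with evidence, three kinds of samples are
                    # built below: (1) each evidence set on its own, (2) the
                    # union of all its evidence sets, and (3) one or two
                    # randomly chosen non-rationale sentences labelled
                    # NOT_ENOUGH_INFO.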
                    for doc_id, evidence_sets in claim['evidence'].items():
                        doc = corpus[int(doc_id)]

                        # Add individual evidence set as samples:
                        for evidence_set in evidence_sets:
                            rationale = [
                                doc['abstract'][i].strip()
                                for i in evidence_set['sentences']
                            ]
                            claims.append(claim['claim'])
                            rationales.append(' '.join(rationale))
                            labels.append(
                                label_encodings[evidence_set['label']])

                        # Add all evidence sets as positive samples
                        rationale_idx = {
                            s
                            for es in evidence_sets for s in es['sentences']
                        }
                        rationale_sentences = [
                            doc['abstract'][i].strip()
                            for i in sorted(list(rationale_idx))
                        ]
                        claims.append(claim['claim'])
                        rationales.append(' '.join(rationale_sentences))
                        labels.append(
                            label_encodings[evidence_sets[0]['label']]
                        )  # directly use the first evidence set label
                        # because currently all evidence sets have
                        # the same label
                        # Add negative samples
                        non_rationale_idx = set(range(len(
                            doc['abstract']))) - rationale_idx
                        non_rationale_idx = random.sample(
                            sorted(non_rationale_idx),  # random.sample on a set is removed in Python 3.11
                            k=min(random.randint(1, 2),
                                  len(non_rationale_idx)))
                        non_rationale_sentences = [
                            doc['abstract'][i].strip()
                            for i in sorted(list(non_rationale_idx))
                        ]
                        claims.append(claim['claim'])
                        rationales.append(' '.join(non_rationale_sentences))
                        labels.append(label_encodings['NOT_ENOUGH_INFO'])

                else:
                    # Add negative samples
                    for doc_id in claim['cited_doc_ids']:
                        doc = corpus[int(doc_id)]
                        non_rationale_idx = random.sample(
                            range(len(doc['abstract'])),
                            k=random.randint(1, 2))
                        non_rationale_sentences = [
                            doc['abstract'][i].strip()
                            for i in non_rationale_idx
                        ]
                        claims.append(claim['claim'])
                        rationales.append(' '.join(non_rationale_sentences))
                        labels.append(label_encodings['NOT_ENOUGH_INFO'])

            return claims, rationales, labels

        def __len__(self):
            return len(self._labels)

        def __getitem__(self, index):
            claim = self._claims[index]
            rationale = self._rationales[index]
            label = self._labels[index]
            return claim, rationale, label

    # Additional janky distributed stuff
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    args.distributed = world_size > 1
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info('Loading training data')
    train_dataset = SciFactLabelPredictionDataset(args.corpus, args.train)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  sampler=get_sampler(train_dataset,
                                                      world_size,
                                                      args.local_rank))

    logger.info('Loading dev data')
    dev_dataset = SciFactLabelPredictionDataset(args.corpus, args.dev)
    dev_dataloader = DataLoader(
        dev_dataset,
        batch_size=args.batch_size,
        sampler=get_sampler(dev_dataset, world_size, args.local_rank),
        shuffle=False  # must stay False (the default) when an explicit sampler is passed
    )

    model = SentenceBertClassifier(model_name=args.model_name,
                                   num_classes=3).cuda()
    optimizer = transformers.AdamW(model.parameters(), lr=args.lr)
    if args.fp16:
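        # apex opt_level 'O1' runs whitelisted ops in fp16 via function
        # patching while model weights remain fp32.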
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    if args.distributed:
        model = DistributedDataParallel(model)
    loss_fn = torch.nn.CrossEntropyLoss()  # Do we need to ignore padding?

    for epoch in range(args.epochs):
        logger.info(f'Epoch: {epoch}')

        logger.info('Training...')
        model.train()
        if args.local_rank == 0:
            iterable = tqdm(train_dataloader)
        else:
            iterable = train_dataloader
        for i, (claims, rationales, labels) in enumerate(iterable):
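            # The optimizer steps once every `accumulation_steps` batches;
            # gradients from the batches in between accumulate in .grad.
            # (The step on the very first batch is a no-op, since no
            # gradients have been computed yet.)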
            if not i % args.accumulation_steps:
                optimizer.step()
                optimizer.zero_grad()

            logits = model(claims, rationales)
            _, preds = logits.max(dim=-1)
            labels = torch.tensor(labels).cuda()
            acc = (preds == labels).float().mean()
            loss = loss_fn(logits, labels)
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            if args.local_rank == 0:
                iterable.set_description(
                    f'Loss: {loss : 0.4f} - Acc: {acc : 0.4f}')

        logger.info('Evaluating...')
        model.eval()
        correct = 0.
        total = 0.
        if args.local_rank == 0:
            iterable = tqdm(dev_dataloader)
        else:
            iterable = dev_dataloader
        for claims, rationales, labels in iterable:
            with torch.no_grad():
                logits = model(claims, rationales)
            _, preds = logits.max(dim=-1)
            labels = torch.tensor(labels).cuda()
            correct += (preds == labels).float().sum()
            total += labels.size(0)
            if args.local_rank == 0:
                acc = correct / total
                iterable.set_description(f'Accuracy: {acc.item() : 0.4f}')

        logger.info('Saving...')
        if args.local_rank == 0:
            torch.save(model.state_dict(), f'{args.ckpt}-{epoch}.pt')
Example #2
    def _configure_apex_amp(self, amp, models, optimizers, apex_args):
        models, optimizers = amp.initialize(models, optimizers, **apex_args)
        return models, optimizers
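
A minimal usage sketch for a helper like this, assuming NVIDIA apex is installed; the toy model, optimizer, and apex_args below are illustrative placeholders, not part of the snippet above:

# Hypothetical caller for a _configure_apex_amp-style helper (illustrative only).
import torch
from apex import amp

model = torch.nn.Linear(128, 3).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
apex_args = {'opt_level': 'O1', 'loss_scale': 'dynamic'}
model, optimizer = amp.initialize(model, optimizer, **apex_args)
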
Example #3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', type=Path, required=True)
    parser.add_argument('--dev', type=Path, required=True)
    parser.add_argument('--model-name', type=str, required=True)

    parser.add_argument('--ckpt', type=str, default='ckpt')
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('-a', '--accumulation_steps', type=int, default=1)

    parser.add_argument('--fp16', action='store_true')
    # Automatically supplied by torch.distributed.launch
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    # Additional janky distributed stuff
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    args.distributed = world_size > 1
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info('Loading training data')
    train_dataset = NliDataset(args.train)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=get_sampler(train_dataset, world_size, args.local_rank)
    )

    logger.info('Loading dev data')
    dev_dataset = NliDataset(args.dev)
    dev_dataloader = DataLoader(
        dev_dataset,
        batch_size=args.batch_size,
        sampler=get_sampler(dev_dataset, world_size, args.local_rank),
        shuffle=False  # must stay False (the default) when an explicit sampler is passed
    )

    model = SentenceBertClassifier(model_name=args.model_name, num_classes=3).cuda()
    optimizer = transformers.AdamW(model.parameters(), lr=args.lr)
    if args.fp16:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    if args.distributed:
        model = DistributedDataParallel(model)
    loss_fn = torch.nn.CrossEntropyLoss()  # Do we need to ignore padding?

    for epoch in range(args.epochs):
        logger.info(f'Epoch: {epoch}')

        logger.info('Training...')
        model.train()
        if args.local_rank == 0:
            iterable = tqdm(train_dataloader)
        else:
            iterable = train_dataloader
        for i, (premises, hypotheses, labels) in enumerate(iterable):
            if not i % args.accumulation_steps:
                optimizer.step()
                optimizer.zero_grad()
            logits = model(premises, hypotheses)
            _, preds = logits.max(dim=-1)
            labels = torch.tensor([LABEL_TO_IDX[l] for l in labels]).cuda()
            acc = (preds == labels).float().mean()
            loss = loss_fn(logits, labels)
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            if args.local_rank == 0:
                iterable.set_description(f'Loss: {loss : 0.4f} - Acc: {acc : 0.4f}')

        logger.info('Evaluating...')
        model.eval()
        correct = 0.
        total = 0.
        if args.local_rank == 0:
            iterable = tqdm(dev_dataloader)
        else:
            iterable = dev_dataloader
        for premises, hypotheses, labels in iterable:
            with torch.no_grad():
                logits = model(premises, hypotheses)
            _, preds = logits.max(dim=-1)
            labels = torch.tensor([LABEL_TO_IDX[l] for l in labels]).cuda()
            correct += (preds == labels).float().sum()
            total += labels.size(0)
            if args.local_rank == 0:
                acc = correct / total
                iterable.set_description(f'Accuracy: {acc.item() : 0.4f}')

        logger.info('Saving...')
        if args.local_rank == 0:
            torch.save(model.state_dict(), f'{args.ckpt}-{epoch}.pt')
Example #4
def main():
    epoch_num = 30000
    batch_size_train = 64
    model_name = 'NUSNet'
    init_seeds(2 + batch_size_train)
    resume = False

    model = NUSNet(3, 1)  # input channels and output channels
    model_info(model, verbose=True)

    print("Using apex synced BN.")
    model = apex.parallel.convert_syncbn_model(model)  # synced BN helper lives in apex.parallel, not apex.amp (assumes `import apex`)
    model.cuda()
    # summary model
    # logging.info(summary(model, (3, 320, 320)))

    optimizer = optim.Adam(model.parameters(),
                           lr=1e-3,
                           betas=(0.9, 0.999),
                           eps=1e-8,
                           weight_decay=0)

    tra_image_dir = os.path.abspath(str(Path('train_data/TR-Image')))
    tra_label_dir = os.path.abspath(str(Path('train_data/TR-Mask')))
    saved_model_dir = os.path.join(os.getcwd(), 'saved_models' + os.sep)
    log_dir = os.path.join(os.getcwd(), 'saved_models',
                           model_name + '_Temp.pth')

    if not os.path.exists(saved_model_dir):
        os.makedirs(saved_model_dir, exist_ok=True)

    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level='O1',
                                      verbosity=0)

    dist.init_process_group(
        backend='nccl',  # 'distributed backend'
        init_method='tcp://127.0.0.1:9999',  # distributed training init method
        world_size=1,  # number of nodes for distributed training
        rank=0)  # distributed training node rank
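    # The single-process group above (world_size=1, rank=0) exists only so
    # that DistributedDataParallel and DistributedSampler can be used when
    # training on a single GPU.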

    model = torch.nn.parallel.DistributedDataParallel(
        model, find_unused_parameters=True)

    start_epoch = 0

    # If there is a saved model, load the model and continue training based on it
    if resume:
        check_file(log_dir)
        checkpoint = torch.load(log_dir, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model'], False)
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']

    img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']

    images_files = sorted(glob.glob(os.path.join(tra_image_dir, '*.*')))
    labels_files = sorted(glob.glob(os.path.join(tra_label_dir, '*.*')))

    tra_img_name_list = [
        x for x in images_files
        if os.path.splitext(x)[-1].lower() in img_formats
    ]
    tra_lbl_name_list = [
        x for x in labels_files
        if os.path.splitext(x)[-1].lower() in img_formats
    ]

    logging.info(
        '================================================================')
    logging.info('train images numbers: %g' % len(tra_img_name_list))
    logging.info('train labels numbers: %g' % len(tra_lbl_name_list))

    assert len(tra_img_name_list) == len(
        tra_lbl_name_list
    ), 'The number of training images: %g  , the number of training labels: %g .' % (
        len(tra_img_name_list), len(tra_lbl_name_list))

    salobj_dataset = SalObjDataset(img_name_list=tra_img_name_list,
                                   lbl_name_list=tra_lbl_name_list,
                                   transform=transforms.Compose([
                                       RescaleT(400),
                                       RandomCrop(300),
                                       ToTensorLab(flag=0)
                                   ]))
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        salobj_dataset)
    salobj_dataloader = torch.utils.data.DataLoader(
        salobj_dataset,
        batch_size=batch_size_train,
        sampler=train_sampler,
        shuffle=False,
        num_workers=16,
        pin_memory=True)

    # training parameter
    ite_num = 0
    running_loss = 0.0  # total_loss = final_fusion_loss +sup1 +sup2 + sup3 + sup4 +sup5 +sup6
    running_tar_loss = 0.0  # final_fusion_loss

    for epoch in range(start_epoch, epoch_num):

        model.train()
        pbar = enumerate(salobj_dataloader)
        pbar = tqdm(pbar, total=len(salobj_dataloader))

        for i, data in pbar:
            ite_num = ite_num + 1

            inputs, labels = data['image'], data['label']
            inputs, labels = inputs.type(torch.FloatTensor), labels.type(
                torch.FloatTensor)
            inputs_v, labels_v = inputs.cuda(non_blocking=True), labels.cuda(
                non_blocking=True)

            # forward + backward + optimize
            final_fusion_loss, sup1, sup2, sup3, sup4, sup5, sup6 = model(
                inputs_v)
            final_fusion_loss_mblf, total_loss = muti_bce_loss_fusion(
                final_fusion_loss, sup1, sup2, sup3, sup4, sup5, sup6,
                labels_v)

            optimizer.zero_grad()
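            # The with-block below lets apex amp scale the loss so that fp16
            # gradients do not underflow; the gradients are unscaled again
            # before optimizer.step().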

            with amp.scale_loss(total_loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            optimizer.step()

            # # print statistics
            running_loss += total_loss.item()
            running_tar_loss += final_fusion_loss_mblf.item()

            # del temporary outputs and loss
            del final_fusion_loss, sup1, sup2, sup3, sup4, sup5, sup6, final_fusion_loss_mblf, total_loss

            s = ('%10s' + '%-15s' + '%10s' + '%-15s' + '%10s' + '%-10d' +
                 '%20s' + '%-10.4f' + '%20s' + '%-10.4f') % (
                     'Epoch: ', '%g/%g' %
                     (epoch + 1, epoch_num), 'Batch: ', '%g/%g' %
                     ((i + 1) * batch_size_train, len(tra_img_name_list)),
                     'Iteration: ', ite_num, 'Total_loss: ',
                     running_loss / ite_num, 'Final_fusion_loss: ',
                     running_tar_loss / ite_num)
            pbar.set_description(s)

        # The model is saved every 50 epochs
        if (epoch + 1) % 50 == 0:
            state = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch + 1
            }
            torch.save(state, saved_model_dir + model_name + ".pth")

    torch.save(model.state_dict(), saved_model_dir + model_name + ".pth")

    # if dist.get_rank() == 0:
    #     torch.save(model.module.state_dict(), saved_model_dir + model_name + ".pth")
    torch.cuda.empty_cache()
def train(hyp):
    cfg = opt.cfg
    # data = opt.data
    epochs = opt.epochs  # 500200 batches at bs 64, 117263 images = 273 epochs
    batch_size = opt.batch_size
    accumulate = max(round(64 / batch_size),
                     1)  # accumulate n times before optimizer update (bs 64)
    weights = opt.weights  # initial training weights
    # Image Sizes
    gs = 32  # (pixels) grid size max stride
    # Configure run
    rank = opt.global_rank
    init_seeds(2 + rank)
    nc = 1 if opt.single_cls else int(len(open(
        opt.names_classes).readlines()))  # number of classes
    hyp['cls'] *= nc / 80  # update coco-tuned hyp['cls'] to current dataset

    # Remove previous results
    try:
        for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
            os.remove(f)
        print("successfully removed last train_batch*.jpg")
    except OSError:
        print("no last train_batch*.jpg to remove")
    # Initialize model
    model = Darknet(opt.cfg, opt.input_size, opt.algorithm_type).to(device)
    cuda = device.type != 'cpu'
    # Optimizer
    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    for k, v in dict(model.named_parameters()).items():
        if '.bias' in k:
            pg2 += [v]  # biases
        elif 'Conv2d.weight' in k:
            pg1 += [v]  # apply weight_decay
        else:
            pg0 += [v]  # all else

    if opt.adam:
        # hyp['lr0'] *= 0.1  # reduce lr (i.e. SGD=5E-3, Adam=5E-4)
        optimizer = optim.Adam(pg0, lr=hyp['lr0'])
        # optimizer = AdaBound(pg0, lr=hyp['lr0'], final_lr=0.1)
    else:
        optimizer = optim.SGD(pg0,
                              lr=hyp['lr0'],
                              momentum=hyp['momentum'],
                              nesterov=True)
    optimizer.add_param_group({
        'params': pg1,
        'weight_decay': hyp['weight_decay']
    })  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    print('Optimizer groups: %g .bias, %g Conv2d.weight, %g other' %
          (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2

    start_epoch = 0
    best_fitness = 0.0
    # attempt_download(weights)
    if weights.endswith('.pt'):  # pytorch format
        # possible weights are '*.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt' etc.
        ckpt = torch.load(weights, map_location=device)

        # load model
        try:
            ckpt['model'] = {
                k: v
                for k, v in ckpt['model'].state_dict().items()
                if model.state_dict()[k].numel() == v.numel()
            }
            model.load_state_dict(ckpt['model'], strict=False)
        except KeyError as e:
            s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s. " \
                "See https://github.com/ultralytics/yolov3/issues/657" % (opt.weights, opt.cfg, opt.weights)
            raise KeyError(s) from e

        # load optimizer
        if ckpt['optimizer'] is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
            best_fitness = ckpt['best_fitness']

        # load results
        if ckpt.get('training_results') is not None:
            with open(results_file, 'w') as file:
                file.write(ckpt['training_results'])  # write results.txt

        # epochs
        start_epoch = ckpt['epoch'] + 1
        if epochs < start_epoch:
            print(
                '%s has been trained for %g epochs. Fine-tuning for %g additional epochs.'
                % (opt.weights, ckpt['epoch'], epochs))
            epochs += ckpt['epoch']  # finetune additional epochs

        del ckpt

    elif len(weights) > 0:  # darknet format
        # possible weights are '*.weights', 'yolov3-tiny.conv.15',  'darknet53.conv.74' etc.
        load_darknet_weights(model, weights)

    if opt.freeze_layers:
        output_layer_indices = [
            idx - 1 for idx, module in enumerate(model.module_list)
            if isinstance(module, YOLOLayer)
        ]
        freeze_layer_indices = [
            x for x in range(len(model.module_list))
            if (x not in output_layer_indices) and (
                x - 1 not in output_layer_indices)
        ]
        for idx in freeze_layer_indices:
            for parameter in model.module_list[idx].parameters():
                parameter.requires_grad_(False)

    # Mixed precision training https://github.com/NVIDIA/apex
    if mixed_precision:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level='O1',
                                          verbosity=0)

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((
        (1 + math.cos(x * math.pi / epochs)) / 2)**1.0) * 0.95 + 0.05  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    scheduler.last_epoch = start_epoch - 1  # see link below
    scaler = torch.cuda.amp.GradScaler(enabled=cuda)  # native AMP grad scaler; apex's amp module has no GradScaler
    # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822

    # Plot lr schedule
    # y = []
    # for _ in range(epochs):
    #     scheduler.step()
    #     y.append(optimizer.param_groups[0]['lr'])
    # plt.plot(y, '.-', label='LambdaLR')
    # plt.xlabel('epoch')
    # plt.ylabel('LR')
    # plt.tight_layout()
    # plt.savefig('LR.png', dpi=300)

    # Initialize distributed training
    # if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
    #     dist.init_process_group(backend='nccl',  # 'distributed backend'
    #                             init_method='tcp://127.0.0.1:9995',  # distributed training init method
    #                             world_size=1,  # number of nodes for distributed training
    #                             rank=0)  # distributed training node rank
    #     model = torch.nn.parallel.DistributedDataParallel(model)
    # model = torch.nn.DataParallel(model).to(device)
    # model.module_list = model.module.module_list
    print("rank is", opt.global_rank)
    ema = torch_utils.ModelEMA(model) if rank in [-1, 0] else None
    if rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[opt.local_rank], output_device=(opt.local_rank))
        model.yolo_layers = model.module.yolo_layers  # re-expose yolo layer indices on the DDP wrapper
    dataloader, dataset = create_dataloader(
        opt.train_path,
        opt.input_size,
        batch_size,
        gs,
        hyp=hyp,
        augment=True,
        cache=False,
        rect=False,
        local_rank=rank,  # Model parameters
        world_size=opt.world_size)
    nb = 64
    if rank in [-1, 0]:
        ema.updates = start_epoch * nb // accumulate  # set EMA updates ***
        # local_rank is set to -1. Because only the first process is expected to do evaluation.
        testloader = create_dataloader(opt.val_path,
                                       opt.input_size,
                                       4,
                                       gs,
                                       hyp=hyp,
                                       augment=False,
                                       cache=False,
                                       rect=True,
                                       local_rank=-1,
                                       world_size=1)[0]
    nw = 8
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # giou loss ratio (obj_loss = 1.0 or giou)
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(
        device)  # attach class weights

    # Model EMA

    # Start training
    nb = len(dataloader)  # number of batches
    n_burn = max(3 * nb,
                 500)  # burn-in iterations, max(3 epochs, 500 iterations)
    maps = np.zeros(nc)  # mAP per class
    # torch.autograd.set_detect_anomaly(True)
    results = (
        0, 0, 0, 0, 0, 0, 0
    )  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
    t0 = time.time()
    # print('Image sizes %g - %g train, %g test' % (imgsz_min, imgsz_max, imgsz_test))
    print('Using %g dataloader workers' % nw)
    print('Starting training for %g epochs...' % epochs)
    for epoch in range(
            start_epoch, epochs
    ):  # epoch ------------------------------------------------------------------
        model.train()
        # Update image weights (optional)
        if dataset.image_weights:
            if rank in [-1, 0]:
                w = model.class_weights.cpu().numpy() * (
                    1 - maps)**2  # class weights
                image_weights = labels_to_image_weights(dataset.labels,
                                                        nc=nc,
                                                        class_weights=w)
                dataset.indices = random.choices(
                    range(dataset.n), weights=image_weights,
                    k=dataset.n)  # rand weighted idx
                # Broadcast if DDP
            if rank != -1:
                indices = torch.zeros([dataset.n], dtype=torch.int)
                if rank == 0:
                    indices[:] = torch.tensor(dataset.indices,
                                              dtype=torch.int)
                dist.broadcast(indices, 0)
                if rank != 0:
                    dataset.indices = indices.cpu().numpy()
        mloss = torch.zeros(4).to(device)  # mean losses
        print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls',
                                     'total', 'targets', 'img_size'))
        if rank != -1:
            dataloader.sampler.set_epoch(epoch)
        pbar = enumerate(dataloader)
        optimizer.zero_grad()
        if opt.local_rank in [-1, 0]:
            print(
                ('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj',
                                       'cls', 'total', 'targets', 'img_size'))
            # pbar = tqdm(pbar, total=nb)  # progress bar
            pbar = tqdm(enumerate(dataloader), total=nb)  # progress bar
        for i, (
                imgs, targets, paths, _
        ) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device).float(
            ) / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
            if opt.multi_scale:
                sz = random.randrange(
                    int(opt.input_size * 0.5),
                    int(opt.input_size * 1.5 + gs)) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]
                          ]  # new shape (stretched to gs-multiple)
                    imgs = F.interpolate(imgs,
                                         size=ns,
                                         mode='bilinear',
                                         align_corners=False)
            targets = targets.to(device)

            # Burn-in
            if ni <= n_burn:
                xi = [0, n_burn]  # x interp
                model.gr = np.interp(
                    ni, xi,
                    [0.0, 1.0])  # giou loss ratio (obj_loss = 1.0 or giou)
                accumulate = max(
                    1,
                    np.interp(ni, xi, [1, 64 / batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(
                        ni, xi,
                        [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                    x['weight_decay'] = np.interp(
                        ni, xi, [0.0, hyp['weight_decay'] if j == 1 else 0.0])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi,
                                                  [0.9, hyp['momentum']])
            # Forward
            pred = model(imgs)

            # Loss
            loss, loss_items = compute_loss(pred, targets, model)
            if rank != -1:
                loss *= opt.world_size  # gradient averaged between devices in DDP mode
            if not torch.isfinite(loss):
                print('WARNING: non-finite loss, ending training ', loss_items)
                return results
            # Backward
            # loss *= batch_size / 64  # scale loss
            # if mixed_precision:
            #     with amp.scale_loss(loss, optimizer) as scaled_loss:
            #         scaled_loss.backward()
            # else:
            #     loss.backward()
            # Optimize
            scaler.scale(loss).backward()
            if ni % accumulate == 0:
                scaler.step(optimizer)  # scaler.step() calls optimizer.step() internally when grads are finite
                scaler.update()
                optimizer.zero_grad()
                # ema.update(model)
                if ema is not None:
                    ema.update(model)
            # Print
            if rank in [-1, 0]:
                mloss = (mloss * i + loss_items) / (i + 1
                                                    )  # update mean losses
                mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9
                                 if torch.cuda.is_available() else 0)  # (GB)
                s = ('%10s' * 2 +
                     '%10.3g' * 6) % ('%g/%g' % (epoch, epochs - 1), mem,
                                      *mloss, len(targets), imgs.shape[-1])
                pbar.set_description(s)

                # Plot
                if ni < 1:
                    f = 'train_batch%g.jpg' % i  # filename
                    res = plot_images(images=imgs,
                                      targets=targets,
                                      paths=paths,
                                      fname=f)
                    if tb_writer:
                        tb_writer.add_image(f,
                                            res,
                                            dataformats='HWC',
                                            global_step=epoch)
                        # tb_writer.add_graph(model, imgs)  # add model to tensorboard

            # end batch ------------------------------------------------------------------------------------------------
        # Update scheduler
        scheduler.step()
        # Process epoch results
        if rank in [-1, 0]:
            ema.update_attr(model)
            final_epoch = epoch + 1 == epochs
            if not opt.notest or final_epoch:  # Calculate mAP
                results, maps, times = test.test(
                    cfg=opt.cfg,
                    names_file=opt.names_classes,
                    batch_size=16,
                    img_size=opt.input_size,
                    conf_thres=0.01,
                    save_json=False,
                    # model=ema.ema.module if hasattr(ema.ema, 'module') else ema.ema,
                    model=ema.ema,
                    single_cls=False,
                    dataloader=testloader,
                    save_dir=wdir)
            # Write
            with open(results_file, 'a') as f:
                f.write(s + '%10.3g' * 7 % results +
                        '\n')  # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
            if len(opt.name) and opt.bucket:
                os.system(
                    'gsutil cp results.txt gs://%s/results/results%s.txt' %
                    (opt.bucket, opt.name))

            # Tensorboard
            if tb_writer:
                tags = [
                    'train/giou_loss', 'train/obj_loss', 'train/cls_loss',
                    'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5',
                    'metrics/F1', 'val/giou_loss', 'val/obj_loss',
                    'val/cls_loss'
                ]
                for x, tag in zip(list(mloss[:-1]) + list(results), tags):
                    tb_writer.add_scalar(tag, x, epoch)

            # Update best mAP
            fi = fitness(np.array(results).reshape(
                1, -1))  # fitness_i = weighted combination of [P, R, mAP, F1]
            if fi > best_fitness:
                best_fitness = fi

            # Save model
            save = (not opt.nosave) or (final_epoch and not opt.evolve)
            if save:
                with open(results_file, 'r') as f:  # create checkpoint
                    ckpt = {
                        'epoch':
                        epoch,
                        'best_fitness':
                        best_fitness,
                        'training_results':
                        f.read(),
                        # 'model': ema.ema.module.state_dict() if hasattr(model, 'module') else ema.ema.state_dict(),
                        'model':
                        ema.ema.module if hasattr(ema, 'module') else ema.ema,
                        'optimizer':
                        None if final_epoch else optimizer.state_dict()
                    }

                # Save last, best and delete
                torch.save(ckpt, last)
                if (best_fitness == fi) and not final_epoch:
                    torch.save(ckpt, best)
                del ckpt

        # end epoch ----------------------------------------------------------------------------------------------------
    # end training
    if rank in [-1, 0]:
        n = opt.name
        if len(n):
            n = '_' + n if not n.isnumeric() else n
            fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
            for f1, f2 in zip(
                [wdir + 'last.pt', wdir + 'best.pt', 'results.txt'],
                [flast, fbest, fresults]):
                if os.path.exists(f1):
                    os.rename(f1, f2)  # rename
                    ispt = f2.endswith('.pt')  # is *.pt
                    strip_optimizer(f2) if ispt else None  # strip optimizer
                    os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)
                              ) if opt.bucket and ispt else None  # upload

        if not opt.evolve:
            plot_results()  # save as results.png
        print('%g epochs completed in %.3f hours.\n' %
              (epoch - start_epoch + 1, (time.time() - t0) / 3600))
    dist.destroy_process_group() if torch.cuda.device_count() > 1 else None
    torch.cuda.empty_cache()
    return results
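
For reference, the scaler.scale / scaler.step / scaler.update calls used above follow PyTorch's native mixed-precision pattern. Below is a minimal, self-contained sketch of that pattern; the toy model, data, and hyperparameters are illustrative assumptions, not part of the script above:

# Native torch.cuda.amp training-step sketch (PyTorch >= 1.6); illustrative only.
import torch

model = torch.nn.Linear(16, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=True)
loss_fn = torch.nn.CrossEntropyLoss()

for _ in range(10):
    x = torch.randn(8, 16, device='cuda')
    y = torch.randint(0, 2, (8,), device='cuda')
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(x), y)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales grads, then calls optimizer.step()
    scaler.update()                # adjusts the loss scale for the next step
    optimizer.zero_grad()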