예제 #1
0
def create_optimizer(net, name, learning_rate, weight_decay, momentum=0, fp16_loss_scale=None,
                     optimizer_state=None, device=None):
    """Place ``net`` on a device and build the optimizer that trains it.

    Args:
        net: network to train; it is cast to float first and, when FP16 is
            requested, converted with apex ``network_to_half``.
        name: optimizer selector, one of ``'sgd'``, ``'adam'`` or ``'adamw'``.
        learning_rate: learning rate handed to the optimizer.
        weight_decay: weight decay handed to the optimizer.
        momentum: momentum, only used by ``'sgd'``.
        fp16_loss_scale: ``None`` disables FP16. ``0`` selects dynamic loss
            scaling; any other value is used as a static loss scale.
        optimizer_state: optional state dict to resume the optimizer from.
        device: forwarded to ``choose_device`` to pick the target device.

    Returns:
        Tuple ``(net, optimizer)``. The optimizer always exposes a
        ``backward(loss)`` method, either from ``FP16_Optimizer`` or attached
        here for the plain optimizers.

    Raises:
        NotImplementedError: if ``name`` is not a supported optimizer.
    """
    net.float()

    half_precision = fp16_loss_scale is not None
    if half_precision:
        from apex import fp16_utils
        net = fp16_utils.network_to_half(net)

    device = choose_device(device)
    print('use', device)
    if device.type == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    net = net.to(device)

    # Only parameters that still require gradients are optimized.
    trainable = list(filter(lambda p: p.requires_grad, net.parameters()))
    print('N of parameters', len(trainable))

    if name == 'sgd':
        optimizer = optim.SGD(trainable, lr=learning_rate, momentum=momentum,
                              weight_decay=weight_decay)
    elif name == 'adam':
        optimizer = optim.Adam(trainable, lr=learning_rate, weight_decay=weight_decay)
    elif name == 'adamw':
        from .adamw import AdamW
        optimizer = AdamW(trainable, lr=learning_rate, weight_decay=weight_decay)
    else:
        raise NotImplementedError(name)

    if not half_precision:
        # Give plain optimizers the same ``backward`` entry point that
        # FP16_Optimizer provides, so callers can treat both alike.
        optimizer.backward = lambda loss: loss.backward()
    else:
        if fp16_loss_scale == 0:
            scale_kwargs = {'dynamic_loss_scale': True}
        else:
            scale_kwargs = {'static_loss_scale': fp16_loss_scale}
        print('FP16_Optimizer', scale_kwargs)
        optimizer = fp16_utils.FP16_Optimizer(optimizer, **scale_kwargs)

    if optimizer_state:
        if not half_precision:
            optimizer.load_state_dict(optimizer_state)
        elif 'optimizer_state_dict' in optimizer_state:
            # NOTE(review): the nested entry is handed to the FP16 wrapper
            # itself -- confirm this matches the checkpoint layout in use.
            optimizer.load_state_dict(optimizer_state['optimizer_state_dict'])
        else:
            # The state was saved from a bare optimizer: restore only the
            # optimizer wrapped inside FP16_Optimizer.
            optimizer.optimizer.load_state_dict(optimizer_state)

    return net, optimizer
예제 #2
0
파일: trainer.py 프로젝트: aihill/kaggle-1
    def __init__(self,
                 model: torch.nn.Module,
                 criterion: torch.nn.Module,
                 optimizer: torch.optim.Optimizer,
                 metrics: Dict[str, Metric],
                 data_loaders: AttrDict,
                 max_norm: float = None,
                 norm_type: int = 2,
                 scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                 is_iteration_scheduler: bool = False,
                 device: torch.cuda.device = None,
                 mixed_precision: bool = False,
                 backup_path: str = None,
                 name: str = 'trainer',
                 logger: Logger = None,
                 finished_epochs: int = 0):
        """Store the training configuration and build the train/eval engines.

        When ``mixed_precision`` is set the model is converted with
        ``network_to_half`` before it is stored or handed to the engines.
        ``finished_epochs`` seeds the iteration counter as
        ``finished_epochs * len(data_loaders.train)``. A logger is obtained
        from ``get_logger()`` when none (or a falsy one) is supplied.
        """
        if mixed_precision:
            model = network_to_half(model)

        # Identity and logging.
        self.name = name
        self.logger = logger if logger else get_logger()

        # Core training objects.
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.metrics = metrics
        self.data_loaders = data_loaders
        self.device = device

        # Values forwarded as max_norm / norm_type to the supervised trainer.
        self.max_norm = max_norm
        self.norm_type = norm_type

        # Learning-rate scheduling.
        self.scheduler = scheduler
        self.is_iteration_scheduler = is_iteration_scheduler

        # Resume bookkeeping.
        self.backup_path = backup_path
        self.finished_epochs = finished_epochs
        iterations_per_epoch = len(data_loaders.train)
        self.finished_iterations = finished_epochs * iterations_per_epoch

        # Progress display.
        self._progress_bar = None
        self._description = 'ITERATION - loss: {:.3f}'

        self._trainer = create_supervised_trainer(
            model=self.model,
            optimizer=self.optimizer,
            loss_fn=self.criterion,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            device=self.device,
            mixed_precision=mixed_precision,
        )
        self._register_handlers(self._trainer)
        self._evaluator = create_supervised_evaluator(self.model, self.metrics, self.device)

        # Running counters updated during training.
        self._epoch = 0
        self._iteration = 0
        self._train_loss = ExponentialMovingAverage()
예제 #3
0
def init_model():
    """
    Build a CUDA resnet50 initialised along the lines of "ImageNet in 1hr".

    What this function does:
        * gamma (weight) of the last batch norm (``bn3``) of every
          ``Bottleneck`` block is replaced by zeros
        * fully connected layer weights are drawn from N(mean=0, std=0.01)
        * the model is moved to the GPU and, when ``args.fp16`` is set
          without ``args.amp``, converted with ``network_to_half``

    Batch norm momentum is not modified here.
    """
    net = models.resnet50()

    residual_blocks = (mod for mod in net.modules() if isinstance(mod, Bottleneck))
    for block in residual_blocks:
        last_bn = block.bn3
        last_bn.weight = Parameter(torch.zeros(last_bn.num_features))

    net.fc.weight.data.normal_(0, 0.01)
    net.cuda()

    manual_half = args.fp16 and not args.amp
    return network_to_half(net) if manual_half else net
예제 #4
0
def train(args, model, dataset, name, is_warmup=False):
    """Train ``model`` with SGD on ``dataset`` while timing each stage of the loop.

    Per batch the function measures host-to-GPU transfer, time spent waiting
    for data, compute time (forward + backward + step) and the full batch
    time. Unless ``is_warmup`` is true, a table of those statistics is printed
    through ``report.print_table`` once training stops.

    Args:
        args: options object; this function reads ``half``, ``lr``,
            ``momentum``, ``weight_decay``, ``static_loss_scale``,
            ``dynamic_loss_scale``, ``prof``, ``epochs``, ``batch_size``,
            ``loader``, ``workers`` and ``report``.
        model: network to train.
        dataset: iterable yielding ``(x, y)`` batches.
        name: model name, only used as a column of the report.
        is_warmup: when true, the final report is skipped.

    Returns:
        None.

    Note:
        ``nn``, ``socket`` and ``report`` are not imported here and must be
        available at module level.
    """
    import time

    import MixedPrecision.tools.utils as utils
    from MixedPrecision.tools.optimizer import OptimizerAdapter
    from MixedPrecision.tools.stats import StatStream
    from MixedPrecision.tools.monitor import make_monitor

    model = utils.enable_cuda(model)

    if args.half:
        from apex.fp16_utils import network_to_half
        model = network_to_half(model)

    criterion = utils.enable_cuda(nn.CrossEntropyLoss())
    # No Half precision for the criterion
    # criterion = utils.enable_half(criterion)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # The adapter exposes ``backward(loss)`` (used in the loop below) and
    # receives the half / loss-scaling options.
    optimizer = OptimizerAdapter(optimizer,
                                 half=args.half,
                                 static_loss_scale=args.static_loss_scale,
                                 dynamic_loss_scale=args.dynamic_loss_scale)
    model.train()

    # Statistic accumulators, all created with drop_first_obs=10.
    # NOTE(review): gpu_compute, iowait, data_loading_gpu and data_loading_cpu
    # are never fed in this function; compute_speed and effective_speed are
    # fed but only referenced by commented-out report rows.
    epoch_compute = StatStream(drop_first_obs=10)
    batch_compute = StatStream(drop_first_obs=10)
    gpu_compute = StatStream(drop_first_obs=10)
    compute_speed = StatStream(drop_first_obs=10)
    effective_speed = StatStream(drop_first_obs=10)
    data_waiting = StatStream(drop_first_obs=10)
    data_loading_gpu = StatStream(drop_first_obs=10)
    data_loading_cpu = StatStream(drop_first_obs=10)
    full_time = StatStream(drop_first_obs=10)
    iowait = StatStream(drop_first_obs=10)
    transfert_time = StatStream(drop_first_obs=10)

    # NOTE(review): these CUDA events are created but not used below; timing
    # is done on the CPU side with time.time() around cuda.synchronize().
    start_event = torch.cuda.Event(enable_timing=True,
                                   blocking=False,
                                   interprocess=False)
    end_event = torch.cuda.Event(enable_timing=True,
                                 blocking=False,
                                 interprocess=False)

    floss = float('inf')

    # Stop after n print when benchmarking (n * batch_count) batch
    print_count = 0
    monitor_proc, gpu_monitor = make_monitor(loop_interval=250)

    def should_run():
        # Without --prof the loop runs for all epochs; otherwise it stops once
        # ``args.prof`` progress lines have been printed.
        if args.prof is None:
            return True
        return print_count < args.prof

    try:
        for epoch in range(0, args.epochs):
            epoch_compute_start = time.time()

            # Looks like it only compute for the current process and not the children
            data_time_start = time.time()

            batch_count = 0
            effective_batch = 0

            for index, (x, y) in enumerate(dataset):
                # Move the batch to the GPU and wait for the copy so the
                # transfer time is measured accurately.
                transfert_start = time.time()
                x = x.cuda()
                y = y.cuda().long()
                torch.cuda.synchronize()

                data_time_end = time.time()
                transfert_time += (data_time_end - transfert_start)
                # Time since the previous batch finished: loader wait + transfer.
                data_waiting += (data_time_end - data_time_start)

                # compute output
                batch_compute_start = time.time()

                output = model(x)
                loss = criterion(output, y)
                floss = loss.item()

                # compute gradient and do SGD step
                optimizer.zero_grad()
                optimizer.backward(loss)
                optimizer.step()

                #print(floss)
                # Synchronize so the CPU-side timestamps include the GPU work.
                torch.cuda.synchronize()
                batch_compute_end = time.time()
                full_time += batch_compute_end - data_time_start
                batch_compute += batch_compute_end - batch_compute_start

                compute_speed += args.batch_size / (batch_compute_end -
                                                    batch_compute_start)
                effective_speed += args.batch_size / (batch_compute_end -
                                                      data_time_start)

                effective_batch += 1

                # Restart the "waiting for data" clock for the next batch.
                data_time_start = time.time()

                batch_count += 1

                # Print a progress line every 10 batches.
                if effective_batch % 10 == 0:
                    print_count += 1
                    speed_avg = args.batch_size / batch_compute.avg

                    print(
                        '[{:4d}][{:4d}] '
                        'Batch Time (avg: {batch_compute.avg:.4f}, sd: {batch_compute.sd:.4f}) '
                        'Speed (avg: {speed:.4f}) '
                        'Data (avg: {data_waiting.avg:.4f}, sd: {data_waiting.sd:.4f})'
                        .format(1 + epoch,
                                batch_count,
                                batch_compute=batch_compute,
                                speed=speed_avg,
                                data_waiting=data_waiting))

                # NOTE(review): this runs once per batch, so epoch_compute
                # records the elapsed time since the epoch started after every
                # batch rather than one value per epoch -- confirm intent.
                epoch_compute_end = time.time()
                epoch_compute.update(epoch_compute_end - epoch_compute_start)

                if not should_run():
                    break
            if not should_run():
                break
    finally:
        # Always stop the GPU monitor, even if training raised.
        print('Done')
        gpu_monitor.stop()
        monitor_proc.terminate()

    if not is_warmup:
        hostname = socket.gethostname()
        current_device = torch.cuda.current_device()
        gpu = torch.cuda.get_device_name(current_device)
        # Keep at most the first 10 characters of the device name.
        gpu = gpu[0:min(10, len(gpu))].strip()

        bs = args.batch_size
        loader = args.loader

        header = [
            'Metric', 'Average', 'Deviation', 'Min', 'Max', 'count', 'half',
            'batch', 'workers', 'loader', 'model', 'hostname', 'GPU'
        ]
        # Trailing columns shared by every report row.
        common = [
            args.half, args.batch_size, args.workers, loader, name, hostname,
            gpu
        ]

        report_data = [
            ['Waiting for data (s)'] + data_waiting.to_array() + common,
            ['GPU Compute Time (s)'] + batch_compute.to_array() + common,
            ['Full Batch Time (s)'] + full_time.to_array() + common,
            # Speeds are batch_size / time, so the slowest time (max) gives
            # the value in the 'Min' column and the fastest (min) the 'Max'.
            [
                'Compute Speed (img/s)', bs / batch_compute.avg, 'NA', bs /
                batch_compute.max, bs / batch_compute.min, batch_compute.count
            ] + common,
            [
                'Effective Speed (img/s)', bs / full_time.avg, 'NA',
                bs / full_time.max, bs / full_time.min, batch_compute.count
            ] + common,
            # Ignored Metric
            #  GPU timed on the CPU side (very close to GPU timing anyway)
            # # ['CPU Compute Time (s)] + batch_compute.to_array() + common,

            #  https://en.wikipedia.org/wiki/Harmonic_mean
            # ['Compute Inst Speed (img/s)'] + compute_speed.to_array() + common,
            # ['Effective Inst Speed (img/s)'] + effective_speed.to_array() + common,

            # ['iowait'] + iowait.to_array() + common
        ]

        # Only some loaders support this
        # So try and print an error but do not fail
        try:
            data_reading = dataset.dataset.read_timer
            data_transform = dataset.dataset.transform_timer
            collate_time = utils.timed_fast_collate.time_stream

            if data_loading_cpu.count > 1:
                report_data += [['Prefetch CPU Data loading (s)'] +
                                data_loading_cpu.to_array() + common]
                report_data += [['Prefetch GPU Data Loading (s)'] +
                                data_loading_gpu.to_array() + common]

            report_data += [['Read Time (s)'] + data_reading.to_array() +
                            common]
            report_data += [['Transform Time (s)'] +
                            data_transform.to_array() + common]
            # Per-process speeds: one image per measured read/transform time.
            report_data += [[
                'Read Speed per process (img/s)', 1.0 / data_reading.avg, 'NA',
                1.0 / data_reading.max, 1.0 / data_reading.min,
                data_reading.count
            ] + common]
            report_data += [[
                'Transform Speed per process  (img/s)',
                1.0 / data_transform.avg, 'NA', 1.0 / data_transform.max,
                1.0 / data_transform.min, data_transform.count
            ] + common]

            # Aggregate speeds: the per-process figures scaled by the number
            # of workers.
            report_data += [[
                'Read Speed (img/s)', args.workers / data_reading.avg, 'NA',
                args.workers / data_reading.max,
                args.workers / data_reading.min, data_reading.count
            ] + common]
            report_data += [[
                'Transform Speed (img/s)', args.workers / data_transform.avg,
                'NA', args.workers / data_transform.max,
                args.workers / data_transform.min, data_transform.count
            ] + common]
            report_data += [[
                'Image Aggregation Speed (img/s)', bs / collate_time.avg, 'NA',
                bs / collate_time.max, bs / collate_time.min,
                collate_time.count
            ] + common]
            report_data += [[
                'Image Aggregation Time (s)', collate_time.avg,
                collate_time.sd, collate_time.max, collate_time.min,
                collate_time.count
            ] + common]
        except Exception as e:
            # Best effort: loaders without these timers just skip the rows.
            print(e)

        report_data.extend(gpu_monitor.arrays(common))
        report.print_table(header, report_data, filename=args.report)

    return
예제 #5
0
    # Read the command-line options that drive the rest of the setup.
    args = parser.parse_args()

    # Cap PyTorch CPU threads at the worker count and seed the CPU and all
    # CUDA generators from the same seed.
    torch.set_num_threads(args.workers)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Project-level switches; presumably these configure what
    # utils.enable_cuda does below -- confirm against the utils module.
    utils.set_use_gpu(args.gpu, not args.no_bench_mode)
    utils.set_use_half(args.half)
    utils.show_args(args)

    data_loader = load_dataset(args, train=True)

    model = utils.enable_cuda(HybridClassifier())

    # Convert the model with network_to_half when --half is requested.
    if args.half:
        model = network_to_half(model)

    criterion = utils.enable_cuda(HybridLoss())

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )
    # Wrap SGD in the project adapter, which receives the half / loss-scaling
    # options.
    optimizer = OptimizerAdapter(
        optimizer,
        half=args.half,
        static_loss_scale=args.static_loss_scale,
        dynamic_loss_scale=args.dynamic_loss_scale
    )
예제 #6
0
def benchmark_training(model, opts):
    """Benchmarks training phase.

    The function adds the derived keys ``distributed``, ``with_cuda``,
    ``fp16`` and ``loss_scale`` to ``opts`` before running. It first runs
    ``opts['num_warmup_batches']`` untimed batches and then records the
    wall-clock time of ``opts['num_batches']`` batches.

    :param obj model: A model to benchmark
    :param dict opts: A dictionary of parameters.
    :rtype: tuple:
    :return: A tuple of (model_name, list of batch times)
    :raises ValueError: If ``opts['phase']`` is not ``'training'`` or FP16 is
                        requested without a GPU.
    :raises RuntimeError: If the data loader yields no batches at all.
    """
    def _reduce_tensor(tensor):
        # Average a tensor over all workers. Only referenced by the
        # commented-out loss reporting in the training loop.
        reduced = tensor.clone()
        dist.all_reduce(reduced, op=dist.reduce_op.SUM)
        reduced /= opts['world_size']
        return reduced

    if opts['phase'] != 'training':
        # Raising a plain string is itself a TypeError in Python 3; raise a
        # real exception that carries the message instead.
        raise ValueError(
            "Phase in benchmark_training func is '%s'" % opts['phase'])

    opts['distributed'] = opts['world_size'] > 1
    opts['with_cuda'] = opts['device'] == 'gpu'
    opts['fp16'] = opts['dtype'] == 'float16'
    opts['loss_scale'] = 1

    if opts['fp16'] and not opts['with_cuda']:
        raise ValueError(
            "Configuration error: FP16 can only be used with GPUs")

    if opts['with_cuda']:
        torch.cuda.set_device(opts['local_rank'])
        cudnn.benchmark = opts['cudnn_benchmark']
        cudnn.fastest = opts['cudnn_fastest']

    if opts['distributed']:
        dist.init_process_group(backend=opts['dist_backend'],
                                init_method='env://')

    if opts['with_cuda']:
        model = model.cuda()
        if opts['dtype'] == 'float16':
            model = network_to_half(model)

    if opts['distributed']:
        model = DDP(model, shared_param=True)

    # In FP16 mode the optimizer updates the master parameters, which are
    # synchronised with the model parameters around every step.
    if opts['fp16']:
        model_params, master_params = prep_param_lists(model)
    else:
        master_params = list(model.parameters())

    criterion = nn.CrossEntropyLoss()
    if opts['with_cuda']:
        criterion = criterion.cuda()
    optimizer = optim.SGD(master_params,
                          lr=0.01,
                          momentum=0.9,
                          weight_decay=1e-4)

    data_loader = DatasetFactory.get_data_loader(opts, opts['__input_shape'],
                                                 opts['__num_classes'])

    is_warmup = opts['num_warmup_batches'] > 0
    # NOTE(review): with zero warm-up batches ``done`` starts out True, so the
    # loop below never runs and all-zero timings are returned -- confirm that
    # this is intended.
    done = opts['num_warmup_batches'] == 0
    num_iterations_done = 0
    model.train()
    batch_times = np.zeros(opts['num_batches'])
    end_time = timeit.default_timer()
    while not done:
        # One pass over the data; restarted until enough batches were timed.
        prefetcher = DataPrefetcher(data_loader, opts)
        batch_data, batch_labels = prefetcher.next()
        if batch_data is None:
            # Without this guard an empty loader makes the outer loop spin
            # forever, because ``done`` can never be set.
            raise RuntimeError(
                "Data loader produced no batches; cannot run the benchmark")
        while batch_data is not None:
            data_var = torch.autograd.Variable(batch_data)
            labels_var = torch.autograd.Variable(batch_labels)

            output = model(data_var)

            loss = criterion(output, labels_var)
            loss = loss * opts['loss_scale']
            # I'll need this for reporting
            #reduced_loss = _reduce_tensor(loss.data) if opts['distributed'] else loss.data

            if opts['fp16']:
                model.zero_grad()
                loss.backward()
                model_grads_to_master_grads(model_params, master_params)
                # Undo the loss scaling on the master gradients.
                if opts['loss_scale'] != 1:
                    for param in master_params:
                        param.grad.data = param.grad.data / opts['loss_scale']
                optimizer.step()
                master_params_to_model_params(model_params, master_params)
            else:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Wait for the GPU so the timestamp below covers the whole batch.
            if opts['with_cuda']:
                torch.cuda.synchronize()

            # Track progress
            num_iterations_done += 1
            cur_time = timeit.default_timer()

            batch_data, batch_labels = prefetcher.next()

            if is_warmup:
                # Warm-up batches are not timed; reset the counter once done.
                if num_iterations_done >= opts['num_warmup_batches']:
                    is_warmup = False
                    num_iterations_done = 0
            else:
                if opts['num_batches'] != 0:
                    batch_times[num_iterations_done - 1] = cur_time - end_time
                if num_iterations_done >= opts['num_batches']:
                    done = True
                    break
            end_time = cur_time

    return (opts['__name'], batch_times)
예제 #7
0
def main():
    """Run landmark-recognition inference on the test set and write ``results.csv``.

    Steps:
        1. Read image ids from ``test.csv`` and write a DALI file list
           (``<a>/<b>/<c>/<id>.jpg 0`` per image, where a, b, c are the first
           three characters of the id) to ``args.data/file_list.txt``.
        2. Build the DALI validation pipeline and the model(s) for
           ``args.arch``, loading weights from ``args.checkpoint``.
        3. For every batch, combine the softmax scores of the two heads (one
           per prime in ``PRIMES``) into a joint score over ``p0 * p1``
           classes, keep the first ``NLABEL`` of them and write the best
           label and its confidence for each image.

    All configuration is read from the module-level ``args``.
    """
    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    df = pd.read_csv('test.csv', index_col=0)
    ids = df.index.tolist()
    image_ids = list(ids)

    # Write the file list; the context manager closes it even if an id is
    # malformed and the write raises.
    txt_path = os.path.join(args.data, 'file_list.txt')
    with open(txt_path, "w") as file_list:
        for image_id in ids:
            file_list.write(image_id[0] + '/' + image_id[1] + '/' + image_id[2]
                            + '/' + image_id + '.jpg 0\n')

    crop_size = 288
    val_size = 288

    pipe = HybridValPipe(batch_size=args.batch_size, num_threads=args.workers,
                         device_id=args.local_rank, data_dir=args.data,
                         crop=crop_size, size=val_size, file_list=txt_path)
    pipe.build()
    dataloader = DALIClassificationIterator(
        pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    if args.arch in model_names:
        # One independent network per prime, each with ``p`` output classes.
        model = [None] * len(PRIMES)
        for i, p in enumerate(PRIMES):
            model[i] = models.__dict__[args.arch](num_classes=p)
            if args.checkpoint:
                model[i].load_state_dict(torch.load(
                    args.checkpoint,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))['state_' + str(p)])
            if torch.cuda.is_available():
                model[i] = model[i].cuda(device)
                if args.fp16:
                    model[i] = network_to_half(model[i])
            model[i].eval()
    else:
        # A single network whose output is indexed per head below.
        model = mynet.__dict__[args.arch](pretrained=None, num_classes=PRIMES)
        model.load_state_dict(torch.load(
            args.checkpoint,
            map_location=lambda storage, loc: storage)['state'])
        if torch.cuda.is_available():
            model = model.cuda(device)

    print('Finished loading model!')

    # NOTE: pickle executes arbitrary code on load; only use a trusted
    # label_id.pkl.
    with open("label_id.pkl", "rb") as label_file:
        label2id = pickle.load(label_file)
    maxlabel = max(label2id)
    assert maxlabel + 1 == NLABEL

    p0 = PRIMES[0]
    p1 = PRIMES[1]

    dp = p1 - p0
    res = [(-p0 * j) % dp for j in range(dp)]

    # NOTE(review): ``tolabel`` is currently unused. It is kept because it
    # appears to belong to the per-prime branch below, which does not write
    # any results yet.
    def tolabel(i):
        j = res.index(i % dp)
        return (i + j * p0) // dp * p1

    # Built once; the module is stateless, so it can be shared by all batches.
    softmax = torch.nn.Softmax(dim=1)

    t02 = time.time()
    ii = 0
    total = len(image_ids)
    with open("results.csv", mode='w+') as of, torch.no_grad():
        of.write('id,landmarks\n')
        for ib, data in enumerate(dataloader):
            inputs = data[0]["data"].to(device, non_blocking=True)
            batch_n = inputs.shape[0]

            if args.arch in model_names:
                # NOTE(review): the joint scores are computed here but never
                # written to the results file -- this branch looks unfinished.
                outputs = model[0](inputs)
                scores = softmax(outputs).unsqueeze(2).expand((batch_n, p0, p1))

                outputs = model[1](inputs)
                scores = scores * softmax(outputs).unsqueeze(1).expand((batch_n, p0, p1))
            else:
                outputs = model(inputs)
                # After flattening, entry k multiplies the first head's score
                # at k % p0 with the second head's score at k % p1.
                scores = softmax(outputs[0]).unsqueeze(1).expand(
                    (batch_n, p1, p0)).reshape((batch_n, p0 * p1))
                scores = scores * softmax(outputs[1]).unsqueeze(1).expand(
                    (batch_n, p0, p1)).reshape((batch_n, p0 * p1))

                # Discard entries beyond the first NLABEL and renormalise.
                scores[:, NLABEL:] = 0
                scores = scores / scores.sum(dim=1, keepdim=True)
                conf, pred = scores.max(dim=1)
                for j in range(batch_n):
                    # Stop writing once every known id has a row.
                    if ii < len(ids):
                        of.write('{:s},{:d} {:.6f}\n'.format(
                            ids[ii].split('.')[0],
                            label2id[pred[j].item()], conf[j].item()))
                        ii = ii + 1

            t01 = t02
            t02 = time.time()
            dt1 = (t02 - t01)
            if (ib + 1) % 10 == 0:
                print('Image {:d}/{:d} time: {:.4f}s'.format(ii, total, dt1))
예제 #8
0
def infer(model,
          path,
          detections_file,
          resize,
          max_size,
          batch_size,
          mixed_precision=True,
          is_master=True,
          world=0,
          annotations=None,
          use_dali=True,
          verbose=True):
    """Run inference on images from ``path`` and optionally evaluate the result.

    Args:
        model: a ``Model`` instance (PyTorch backend); anything else is
            treated as a TensorRT engine.
        path: directory containing the images.
        detections_file: where to write the detections as JSON; falsy to skip.
        resize, max_size, batch_size: forwarded to the data iterator. With
            TensorRT, ``max_size`` is overridden by ``max(model.input_size)``.
        mixed_precision: convert a PyTorch model with ``network_to_half``.
        is_master: only the master copies results to the host, writes the
            detections and runs the evaluation.
        world: number of participating devices; results are all-gathered
            when it is greater than 1.
        annotations: path to an annotation file. When omitted, a temporary
            JSON file listing every file in ``path`` is created.
        use_dali: choose ``DaliDataIterator`` over ``DataIterator``.
        verbose: print progress information.
    """

    backend = 'pytorch' if isinstance(model, Model) else 'tensorrt'

    # Create annotations if none was provided
    if not annotations:
        images = [{
            'id': i,
            'file_name': f
        } for i, f in enumerate(os.listdir(path))]
        # NamedTemporaryFile creates the file atomically (unlike the
        # deprecated mktemp) and the ``with`` block closes it before the data
        # iterator opens it. ``delete=False`` keeps it on disk afterwards.
        with tempfile.NamedTemporaryFile('w', suffix='.json',
                                         delete=False) as tmp:
            json.dump({'images': images}, tmp)
            annotations = tmp.name

    # TensorRT only supports fixed input sizes, so override input size accordingly
    if backend == 'tensorrt':
        max_size = max(model.input_size)

    # Prepare dataset
    if verbose:
        print('Preparing dataset...')
    data_iterator = (DaliDataIterator if use_dali else DataIterator)(
        path,
        resize,
        max_size,
        batch_size,
        model.stride,
        world,
        annotations,
        training=False)
    if verbose:
        print(data_iterator)

    # Prepare model
    # Strings are compared with ``==``: ``is`` tests identity and only
    # works by accident of interning.
    if backend == 'pytorch':
        if torch.cuda.is_available():
            model = model.cuda()
        if mixed_precision:
            model = network_to_half(model)
        model.eval()

    if verbose:
        print('   backend: {}'.format(backend))
        print('    device: {} {}'.format(
            world, 'cpu' if not torch.cuda.is_available() else
            'gpu' if world == 1 else 'gpus'))
        print('     batch: {}, precision: {}'.format(
            batch_size, 'unknown' if backend == 'tensorrt' else
            'mixed' if mixed_precision else 'full'))
        print('Running inference...')

    results = []
    profiler = Profiler(['infer', 'fw'])
    with torch.no_grad():
        for i, (data, ids, ratios) in enumerate(data_iterator):
            # Forward pass
            profiler.start('fw')
            scores, boxes, classes = model(data)
            profiler.stop('fw')

            results.append([scores, boxes, classes, ids, ratios])

            profiler.bump('infer')
            # Report once 'infer' has accumulated more than 60 and on the
            # last batch, then reset the profiler.
            if verbose and (profiler.totals['infer'] > 60
                            or i == len(data_iterator) - 1):
                size = len(data_iterator.ids)
                msg = '[{:{len}}/{}]'.format(min((i + 1) * batch_size, size),
                                             size,
                                             len=len(str(size)))
                msg += ' {:.3f}s/{}-batch'.format(profiler.means['infer'],
                                                  batch_size)
                msg += ' (fw: {:.3f}s)'.format(profiler.means['fw'])
                msg += ', {:.1f} im/s'.format(batch_size /
                                              profiler.means['infer'])
                print(msg, flush=True)

                profiler.reset()

    # Gather results from all devices
    if verbose:
        print('Gathering results...')
    # Transpose the per-batch lists into one concatenated tensor per field.
    results = [torch.cat(r, dim=0) for r in zip(*results)]
    if world > 1:
        for r, result in enumerate(results):
            all_result = [
                torch.ones_like(result, device=result.device)
                for _ in range(world)
            ]
            torch.distributed.all_gather(list(all_result), result)
            results[r] = torch.cat(all_result, dim=0)

    if is_master:
        # Copy buffers back to host
        results = [r.cpu() for r in results]

        # Collect detections
        detections = []
        for scores, boxes, classes, ids, ratios in zip(*results):
            # Keep only entries with a positive score.
            keep = (scores > 0).nonzero()
            scores = scores[keep].view(-1)
            # Divide by the ratio reported by the data iterator for this image.
            boxes = boxes[keep, :].view(-1, 4) / ratios
            classes = classes[keep].view(-1).int()

            for score, box, cat in zip(scores, boxes, classes):
                x1, y1, x2, y2 = box.data.tolist()
                cat = cat.item()
                if 'annotations' in data_iterator.coco.dataset:
                    cat = data_iterator.coco.getCatIds()[cat]
                detections.append({
                    'image_id': ids.item(),
                    'score': score.item(),
                    # Stored as [x, y, width, height].
                    'bbox': [x1, y1, x2 - x1 + 1, y2 - y1 + 1],
                    'category_id': cat
                })

        if detections:
            # Save detections
            if detections_file and verbose:
                print('Writing {}...'.format(detections_file))
            detections = {'annotations': detections}
            detections['images'] = data_iterator.coco.dataset['images']
            if 'categories' in data_iterator.coco.dataset:
                detections['categories'] = [
                    data_iterator.coco.dataset['categories']
                ]
            if detections_file:
                # Close the file deterministically so the JSON is fully
                # flushed to disk.
                with open(detections_file, 'w') as out:
                    json.dump(detections, out, indent=4)

            # Evaluate model on dataset
            if 'annotations' in data_iterator.coco.dataset:
                if verbose:
                    print('Evaluating model...')
                with redirect_stdout(None):
                    coco_pred = data_iterator.coco.loadRes(
                        detections['annotations'])
                    coco_eval = COCOeval(data_iterator.coco, coco_pred, 'bbox')
                    coco_eval.evaluate()
                    coco_eval.accumulate()
                coco_eval.summarize()
        else:
            print('No detections!')
예제 #9
0
def main():
    """Train landmark classifiers whose targets are labels modulo each prime.

    Everything is configured through the module-level ``args``. The steps are:

    1. Read ``train.csv`` from ``args.data``, relabel ``landmark_id`` with
       consecutive integers in ascending order of frequency, and sample 1/50
       of the images whose landmark has more than two images as validation.
    2. Write ``train.txt`` / ``val.txt`` file lists into ``args.data`` and
       build the DALI train/val iterators from them.
    3. Build the model(s), SGD optimizer(s) and StepLR scheduler(s): one
       model per entry of ``PRIMES`` when ``args.arch`` is in ``model_names``,
       otherwise a single ``mynet`` model with one output per prime.
    4. Run the train/val epochs. A sample counts as correct only when the
       prediction matches ``target % p`` for every prime. A checkpoint is
       written after each epoch and the best one (by the accuracy of the last
       phase, i.e. validation) is copied to ``model_best.pth.tar``.
    """
    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    # makedirs(exist_ok=True) also handles nested paths and a folder that
    # appears between the check and the creation.
    os.makedirs(args.save_folder, exist_ok=True)

    csv_file = os.path.join(args.data, 'train.csv')
    df = pd.read_csv(csv_file, index_col=0)
    df = df.drop(['url'], axis=1)

    # Relabel landmarks with consecutive integers, least frequent first.
    df_count = df.groupby('landmark_id').size()
    df_count = df_count.sort_values()
    df_count = df_count.to_frame('count')
    df_count['label'] = np.arange(len(df_count))
    label_dict = df_count.loc[:, 'label'].to_dict()

    df['label'] = df['landmark_id'].map(label_dict)
    # Labels grow with frequency, so ``label >= label_start`` selects the
    # landmarks that have more than two images.
    label_start = df_count[df_count['count'] > 2].iloc[0, 1]
    df2 = df.loc[df['label'] >= label_start]

    r = df2.shape[0]
    # ``np.int`` was removed from NumPy; the builtin truncates identically.
    rs = int(r / 50)
    print('Number of images:', df.shape[0])
    print('Number of labels:', df_count.shape[0])
    print('We sampled ', rs, 'starting from label', label_start,
          'as validation data')

    labels = dict()
    labels['val'] = df2['label'].sample(n=rs)
    labels['train'] = df['label'].drop(labels['val'].index)

    txt_path = dict()
    for phase in ['train', 'val']:
        txt_path[phase] = os.path.join(args.data, phase + '.txt')
        lc1 = labels[phase].index.tolist()
        lc2 = labels[phase].tolist()
        # The context manager closes the list file even if a write fails.
        with open(txt_path[phase], "w") as list_file:
            for image_id, ll in zip(lc1, lc2):
                # One line per image: "<c0>/<c1>/<c2>/<id>.jpg <label>".
                list_file.write(image_id[0] + '/' + image_id[1] + '/' +
                                image_id[2] + '/' + image_id + '.jpg' + ' ' +
                                str(ll) + '\n')
    del df, df_count, df2, labels, label_dict

    crop_size = 224
    val_size = 256
    dataloader = dict()

    print('use ' + ['GPU', 'CPU'][args.dali_cpu] + ' to load data')
    print('Half precision:' + str(args.fp16))

    pipe = HybridTrainPipe(batch_size=args.batch_size,
                           num_threads=args.workers,
                           device_id=args.local_rank,
                           data_dir=args.data,
                           crop=crop_size,
                           dali_cpu=args.dali_cpu,
                           file_list=txt_path['train'])
    pipe.build()
    dataloader['train'] = DALIClassificationIterator(
        pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

    pipe = HybridValPipe(batch_size=args.batch_size,
                         num_threads=args.workers,
                         device_id=args.local_rank,
                         data_dir=args.data,
                         crop=crop_size,
                         size=val_size,
                         file_list=txt_path['val'])
    pipe.build()
    dataloader['val'] = DALIClassificationIterator(
        pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    criterion = nn.CrossEntropyLoss().cuda()
    model = [None] * len(PRIMES)
    optimizer = [None] * len(PRIMES)
    scheduler = [None] * len(PRIMES)

    if args.arch in model_names:
        # Read the checkpoint once; each per-prime model picks its own entry.
        # Storages stay on the CPU: load_state_dict copies them onto whatever
        # device the model parameters already live on.
        checkpoint = None
        if args.checkpoint:
            print('Resuming training from epoch {}, loading {}...'.format(
                args.resume_epoch, args.checkpoint))
            checkpoint = torch.load(
                args.checkpoint, map_location=lambda storage, loc: storage)

        for i, p in enumerate(PRIMES):
            model[i] = models.__dict__[args.arch](num_classes=p)
            if checkpoint is None:
                # Start from the pretrained weights, keeping only the first
                # ``p`` rows of the final layer to match ``num_classes=p``.
                model_type = ''.join(
                    [ch for ch in args.arch if not ch.isdigit()])
                model_url = models.__dict__[model_type].model_urls[args.arch]
                pre_trained = model_zoo.load_url(model_url)
                pre_trained['fc.weight'] = pre_trained['fc.weight'][:p, :]
                pre_trained['fc.bias'] = pre_trained['fc.bias'][:p]
                model[i].load_state_dict(pre_trained)
            if torch.cuda.is_available():
                model[i] = model[i].cuda(device)
                if args.fp16:
                    model[i] = network_to_half(model[i])
        for i, _ in enumerate(PRIMES):
            optimizer[i] = optim.SGD(model[i].parameters(),
                                     lr=args.lr,
                                     momentum=0.9,
                                     weight_decay=args.weight_decay)
            if checkpoint is not None:
                # Restore after the optional fp16 wrapping so the parameter
                # names match what model[i].state_dict() produced on save.
                # The key uses the list index, exactly as it is written below.
                model[i].load_state_dict(checkpoint['state_' + str(i)])
            scheduler[i] = optim.lr_scheduler.StepLR(optimizer[i],
                                                     step_size=args.step_size,
                                                     gamma=0.1)
            # Fast-forward this scheduler to the resume epoch. The counter
            # must not reuse ``i`` or it would step other schedulers.
            for _ in range(args.resume_epoch):
                scheduler[i].step()
            if args.fp16:
                # The training loop calls optimizer[i].backward(loss) in fp16
                # mode, which only the FP16_Optimizer wrapper provides.
                optimizer[i] = FP16_Optimizer(
                    optimizer[i],
                    static_loss_scale=args.static_loss_scale,
                    dynamic_loss_scale=args.dynamic_loss_scale)
    else:
        if args.checkpoint:
            model = mynet.__dict__[args.arch](pretrained=None,
                                              num_classes=PRIMES)
            model.load_state_dict(
                torch.load(args.checkpoint,
                           map_location=lambda storage, loc: storage)['state'])
        else:
            model = mynet.__dict__[args.arch](pretrained='imagenet',
                                              num_classes=PRIMES)

        if torch.cuda.is_available():
            model = model.cuda(device)
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=args.weight_decay)
        # NOTE(review): the optimizer state is only restored for a negative
        # resume_epoch; ``> 0`` may have been intended -- confirm.
        if args.checkpoint and args.resume_epoch < 0:
            optimizer.load_state_dict(
                torch.load(args.checkpoint,
                           map_location=lambda storage, loc: storage)['optim'])
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=args.step_size,
                                              gamma=0.1)
        for _ in range(args.resume_epoch):
            scheduler.step()
        if args.fp16:
            model = network_to_half(model)
            optimizer = FP16_Optimizer(
                optimizer,
                static_loss_scale=args.static_loss_scale,
                dynamic_loss_scale=args.dynamic_loss_scale)

    best_acc = 0
    for epoch in range(args.resume_epoch, args.epochs):
        print('Epoch {}/{}'.format(epoch + 1, args.epochs))
        print('-' * 5)
        for phase in ['train', 'val']:
            if args.arch in model_names:
                if phase == 'train':
                    for i, _ in enumerate(PRIMES):
                        scheduler[i].step()
                        model[i].train()
                else:
                    for i, _ in enumerate(PRIMES):
                        model[i].eval()
            else:
                if phase == 'train':
                    scheduler.step()
                    model.train()
                else:
                    model.eval()

            num = 0  # samples seen in this phase
            csum = 0  # samples predicted correctly for every prime
            running_loss = 0.0
            cur = 0  # samples seen since the last progress print
            cur_loss = 0.0

            print(phase, ':')
            end = time.time()
            for ib, data in enumerate(dataloader[phase]):
                data_time = time.time() - end
                inputs = data[0]["data"].to(device, non_blocking=True)
                targets = data[0]["label"].squeeze().to(device,
                                                        non_blocking=True)
                if args.arch in model_names:
                    for i, _ in enumerate(PRIMES):
                        optimizer[i].zero_grad()
                else:
                    optimizer.zero_grad()

                batch_size = targets.size(0)
                # Starts as all ones and is multiplied by each per-prime hit
                # mask, so it stays 1 only where every residue is right.
                correct = torch.ones((batch_size),
                                     dtype=torch.uint8).to(device)
                with torch.set_grad_enabled(phase == 'train'):
                    if args.arch in model_names:
                        for i, p in enumerate(PRIMES):
                            outputs = model[i](inputs)
                            targetp = (targets % p).long()
                            loss = criterion(outputs, targetp)
                            if phase == 'train':
                                #loader_len = int(dataloader[phase]._size / args.batch_size)
                                #adjust_learning_rate(optimizer[i], epoch,ib+1, loader_len)
                                if args.fp16:
                                    optimizer[i].backward(loss)
                                else:
                                    loss.backward()
                                optimizer[i].step()
                            _, pred = outputs.topk(1, 1, True, True)
                            correct = correct.mul(pred.view(-1).eq(targetp))
                    elif args.arch in mynet.__dict__:
                        outputs = model(inputs)
                        loss = 0.0
                        for i, p in enumerate(PRIMES):
                            targetp = (targets % p).long()
                            loss += criterion(outputs[i], targetp)
                            _, pred = outputs[i].topk(1, 1, True, True)
                            correct = correct.mul(pred.view(-1).eq(targetp))
                        if phase == 'train':
                            if args.fp16:
                                optimizer.backward(loss)
                            else:
                                loss.backward()
                            optimizer.step()

                num += batch_size
                csum += correct.float().sum(0)
                acc1 = csum / num * 100
                # In the per-prime branch ``loss`` is the loss of the last
                # prime only; in the mynet branch it is the sum over primes.
                running_loss += loss.item() * batch_size
                average_loss = running_loss / num
                cur += batch_size
                cur_loss += loss.item() * batch_size
                cur_avg_loss = cur_loss / cur
                batch_time = time.time() - end
                end = time.time()
                if (ib + 1) % args.print_freq == 0:
                    print(
                        '{} L:{:.4f} correct:{:.0f} acc1:{:.4f} data:{:.2f}s batch:{:.2f}s'
                        .format(num, cur_avg_loss, csum, acc1, data_time,
                                batch_time))
                    cur = 0
                    cur_loss = 0.0

            print('------SUMMARY:', phase, '---------')
            print('E:{} L:{:.4f} correct:{:.0f} acc1: {:.4f} Time: {:.4f}s'.
                  format(epoch, average_loss, csum, acc1, batch_time))
            dataloader[phase].reset()

        # Save the state. ``acc1`` here is the accuracy of the last phase
        # that ran, i.e. validation.
        save_file = os.path.join(args.save_folder,
                                 'epoch_' + str(epoch + 1) + '.pth')
        save_dict = {
            'epoch': epoch + 1,
            'acc': acc1,
            'arch': args.arch,
        }
        if args.arch in model_names:
            for i, _ in enumerate(PRIMES):
                save_dict['state_' + str(i)] = model[i].state_dict()
                save_dict['optim_' + str(i)] = optimizer[i].state_dict()
        elif args.arch in mynet.__dict__:
            save_dict['state'] = model.state_dict()
            save_dict['optim'] = optimizer.state_dict()
            save_dict['primes'] = PRIMES
        torch.save(save_dict, save_file)
        if acc1 > best_acc:
            # Remember the new best; otherwise every epoch would overwrite
            # the "best" checkpoint.
            best_acc = acc1
            shutil.copyfile(save_file, 'model_best.pth.tar')
예제 #10
0
def main(args):
    """Build the model, optimizer and loss from ``args`` and start inference.

    The training/evaluation loop is currently disabled: it is kept below
    inside string literals, which are no-op expressions. The live code seeds
    torch, selects the default tensor type, loads the dataset named by
    ``args.dataset_name``, builds the model, optimizer, scheduler and loss,
    optionally restores a checkpoint, and finally collects every test batch
    and starts two daemon processes connected by a queue.

    :param args: the parsed program arguments.
    :raises ValueError: if ``args.dataset_name`` is not a known dataset.
    :raises Exception: on an unknown tensor, optimizer, scheduler, loss
        reduction or loss type.
    """
    # Several locals below are only referenced by the disabled code; they are
    # kept so that code can be re-enabled without further changes.
    is_debug = args.is_debug
    dataset_name = args.dataset_name
    preserve_energy = args.preserve_energy
    compress_rate = args.compress_rate
    # Disabled code (log-file headers), kept as a no-op string literal.
    """
    DATASET_HEADER = HEADER + ",dataset," + str(dataset_name) + \
                     "-current-preserve-energy-" + str(preserve_energy) + "\n"
    
    if args.test_compress_rates:
        dataset_log_file = os.path.join(results_folder_name,
                                        f"{args.dataset}-dataset-compress-rates.log")
    else:
        dataset_log_file = os.path.join(
            results_folder_name,
            get_log_time() + "-dataset-" + str(dataset_name) + \
            "-preserve-energy-" + str(preserve_energy) + \
            "-compress-rate-" + str(compress_rate) + \
            ".log")
        with open(dataset_log_file, "a") as file:
            # Write the metadata.
            file.write(DATASET_HEADER)
            # Write the header with the names of the columns.
            file.write(
                "epoch,train_loss,train_accuracy,dev_loss,dev_accuracy,"
                "test_loss,test_accuracy,epoch_time,learning_rate,"
                "train_time,test_time,compress_rate\n")

    # with open(os.path.join(results_dir, additional_log_file), "a") as file:
    #     # Write the metadata.
    #     file.write(DATASET_HEADER)

    with open(os.path.join(results_dir, mem_log_file), "a") as file:
        # Write the metadata.
        file.write(DATASET_HEADER)
    """
    torch.manual_seed(args.seed)

    optimizer_type = args.optimizer_type
    scheduler_type = args.scheduler_type
    loss_type = args.loss_type
    loss_reduction = args.loss_reduction

    use_cuda = args.use_cuda
    device = torch.device("cuda" if use_cuda else "cpu")
    tensor_type = args.tensor_type
    if use_cuda and args.noise_sigma is False:
        if tensor_type is TensorType.FLOAT32:
            cuda_type = torch.cuda.FloatTensor
        elif tensor_type is TensorType.FLOAT16:
            cuda_type = torch.cuda.HalfTensor
        elif tensor_type is TensorType.DOUBLE:
            cuda_type = torch.cuda.DoubleTensor
        else:
            raise Exception(f"Unknown tensor type: {tensor_type}")
        # The below has to be disabled for normal distribution to work (add noise).
        torch.set_default_tensor_type(cuda_type)
    elif use_cuda is False and args.noise_sigma is False:
        if tensor_type is TensorType.FLOAT32:
            cpu_type = torch.FloatTensor
        elif tensor_type is TensorType.FLOAT16:
            cpu_type = torch.HalfTensor
        elif tensor_type is TensorType.DOUBLE:
            cpu_type = torch.DoubleTensor
        else:
            raise Exception(f"Unknown tensor type: {tensor_type}")
        torch.set_default_tensor_type(cpu_type)

    train_loader, dev_loader, test_loader = None, None, None
    # Strings must be compared with ``==``: ``is`` tests object identity and
    # only matches when the interpreter happens to share the string object.
    if dataset_name == "cifar10" or dataset_name == "cifar100":
        train_loader, test_loader, _, _ = get_cifar(args, dataset_name)
    elif dataset_name == "mnist":
        train_loader, test_loader = get_mnist(args)
    elif dataset_name in os.listdir(ucr_path):  # dataset from UCR archive
        train_loader, test_loader, dev_loader = get_ucr(args)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    model = getModelPyTorch(args=args)
    model.to(device)
    # model = torch.nn.DataParallel(model)

    # https://pytorch.org/docs/master/notes/serialization.html
    if args.model_path != "no_model":
        model.load_state_dict(
            torch.load(os.path.join(models_dir, args.model_path),
                       map_location=device))
        msg = "loaded model: " + args.model_path
        # logger.info(msg)
        print(msg)
    if args.precision_type is PrecisionType.FP16:
        # model.half()  # convert to half precision
        model = network_to_half(model)
    # You want to make sure that the BatchNormalization layers use float32
    # for accumulation or you will have convergence issues.
    # https://discuss.pytorch.org/t/training-with-half-precision/11815
    # for layer in model.modules():
    #     if isinstance(layer, nn.BatchNorm1d) or isinstance(layer,
    #                                                        nn.BatchNorm2d):
    #         layer.float()

    params = model.parameters()
    eps = 1e-8

    if optimizer_type is OptimizerType.MOMENTUM:
        optimizer = optim.SGD(params,
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif optimizer_type is OptimizerType.ADAM_FLOAT16:
        optimizer = AdamFloat16(params, lr=args.learning_rate, eps=eps)
    elif optimizer_type is OptimizerType.ADAM:
        optimizer = optim.Adam(params,
                               lr=args.learning_rate,
                               betas=(args.adam_beta1, args.adam_beta2),
                               weight_decay=args.weight_decay,
                               eps=eps)
    else:
        raise Exception(f"Unknown optimizer type: {optimizer_type.name}")

    # https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau
    if scheduler_type is SchedulerType.ReduceLROnPlateau:
        scheduler = ReduceLROnPlateauPyTorch(optimizer=optimizer,
                                             mode='min',
                                             factor=0.1,
                                             patience=10)
    elif scheduler_type is SchedulerType.MultiStepLR:
        scheduler = MultiStepLR(optimizer=optimizer, milestones=[150, 250])
    else:
        raise Exception(f"Unknown scheduler type: {scheduler_type}")

    if args.precision_type is PrecisionType.FP16:

        #amp_handle: tells it where back-propagation occurs so that it can
        #properly scale the loss and clear internal per-iteration state.

        # amp_handle = amp.init()
        # optimizer = amp_handle.wrap_optimizer(optimizer)

        # The optimizer supported by apex.
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   verbose=True)
    # Disabled code (best-metric trackers), kept as a no-op string literal.
    """
    # max = choose the best model.
    min_train_loss = min_test_loss = min_dev_loss = sys.float_info.max
    max_train_accuracy = max_test_accuracy = max_dev_accuracy = 0.0
    """
    # Optionally resume from a checkpoint.
    if args.resume:
        # Use a local scope to avoid dangling references.
        def resume():
            if os.path.isfile(args.resume):
                #print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(
                    args.resume,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                max_train_accuracy = checkpoint['max_train_accuracy']
                model.load_state_dict(checkpoint['state_dict'])
                # An FP16_Optimizer instance's state dict internally stashes the master params.
                optimizer.load_state_dict(checkpoint['optimizer'])
                #print("=> loaded checkpoint '{}' (epoch {})"
                #      .format(args.resume, checkpoint['epoch']))
                return max_train_accuracy
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))
                return 0.0

        max_train_accuracy = resume()

    if loss_reduction is LossReduction.ELEMENTWISE_MEAN:
        reduction_function = "mean"
    elif loss_reduction is LossReduction.SUM:
        reduction_function = "sum"
    else:
        raise Exception(f"Unknown loss reduction: {loss_reduction}")

    if loss_type is LossType.CROSS_ENTROPY:
        loss_function = torch.nn.CrossEntropyLoss(reduction=reduction_function)
    elif loss_type is LossType.NLL:
        loss_function = torch.nn.NLLLoss(reduction=reduction_function)
    else:
        raise Exception(f"Unknown loss type: {loss_type}")
    # Disabled code (visualization, the epoch loop, logging and the final
    # test run), kept as a no-op string literal.
    """
    if args.visulize is True:
        start_visualize_time = time.time()
        test_loss, test_accuracy = test(
            model=model, device=device, test_loader=test_loader,
            loss_function=loss_function, args=args)
        elapsed_time = time.time() - start_visualize_time
        print("test time: ", elapsed_time)
        print("test accuracy: ", test_accuracy)
        with open(global_log_file, "a") as file:
            file.write(
                dataset_name + ",None,None,None,None," + str(
                    test_loss) + "," + str(test_accuracy) + "," + str(
                    elapsed_time) + ",visualize," + str(
                    args.preserve_energy) + "," + str(
                    args.compress_rate) + "\n")
        return

    dataset_start_time = time.time()
    dev_loss = min_dev_los = sys.float_info.max
    dev_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs + 1):
        epoch_start_time = time.time()
        # print("\ntrain:")
        if args.log_conv_size is True:
            with open(additional_log_file, "a") as file:
                file.write(str(args.compress_rate) + ",")
        train_start_time = time.time()
        train_loss, train_accuracy = train(
            model=model, device=device, train_loader=train_loader, args=args,
            optimizer=optimizer, loss_function=loss_function, epoch=epoch)
        train_time = time.time() - train_start_time
        if args.is_dev_dataset:
            if dev_loader is None:
                raise Exception("The dev_loader was not set! Check methods to"
                                "get the data, e.g. get_ucr()")
            dev_loss, dev_accuracy = test(
                model=model, device=device, test_loader=dev_loader,
                loss_function=loss_function, args=args)
        # print("\ntest:")
        test_start_time = time.time()
        if args.log_conv_size is True or args.mem_test is True:
            test_loss, test_accuracy = 0, 0
        else:
            test_loss, test_accuracy = test(
                model=model, device=device, test_loader=test_loader,
                loss_function=loss_function, args=args)
        test_time = time.time() - test_start_time
        # Scheduler step is based only on the train data, we do not use the
        # test data to schedule the decrease in the learning rate.
        scheduler.step(train_loss)

        epoch_time = time.time() - epoch_start_time

        raw_optimizer = optimizer
        if args.precision_type is PrecisionType.FP16:
            raw_optimizer = optimizer.optimizer
        lr = f"unknown (started with: {args.learning_rate})"
        if len(raw_optimizer.param_groups) > 0:
            lr = raw_optimizer.param_groups[0]['lr']

        with open(dataset_log_file, "a") as file:
            file.write(str(epoch) + "," + str(train_loss) + "," + str(
                train_accuracy) + "," + str(dev_loss) + "," + str(
                dev_accuracy) + "," + str(test_loss) + "," + str(
                test_accuracy) + "," + str(epoch_time) + "," + str(
                lr) + "," + str(train_time) + "," + str(test_time) + "," + str(
                args.compress_rate) + "\n")

        # Metric: select the best model based on the best train loss (minimal).
        is_best = False
        if (epoch == args.start_epoch) or (train_loss < min_train_loss) or (
                dev_loss < min_dev_loss):
            min_train_loss = train_loss
            max_train_accuracy = train_accuracy
            min_dev_loss = dev_loss
            max_dev_accuracy = dev_accuracy
            min_test_loss = test_loss
            max_test_accuracy = test_accuracy
            is_best = True
            model_path = os.path.join(models_dir,
                                      get_log_time() + "-dataset-" + str(
                                          dataset_name) + \
                                      "-preserve-energy-" + str(
                                          preserve_energy) + \
                                      "-compress-rate-" + str(
                                          args.compress_rate) + \
                                      "-test-accuracy-" + str(
                                          test_accuracy) + ".model")
            torch.save(model.state_dict(), model_path)

        # Save the checkpoint (to resume training).
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'max_train_accuracy': max_train_accuracy,
            'optimizer': optimizer.state_dict(),
        }, is_best,
            filename="dataset-" + dataset_name + "-max-train-accuracy-" + str(
                max_train_accuracy) + "-max-test-accuracy-" + str(
                max_test_accuracy) + "-compress-rate-" + str(
                args.compress_rate) + "-" + "checkpoint.tar")

    with open(global_log_file, "a") as file:
        file.write(dataset_name + "," + str(min_train_loss) + "," + str(
            max_train_accuracy) + "," + str(min_dev_loss) + "," + str(
            max_dev_accuracy) + "," + str(min_test_loss) + "," + str(
            max_test_accuracy) + "," + str(
            time.time() - dataset_start_time) + "\n")
    test_time = time.time()
    test(model = model, device = device, test_loader = test_loader, loss_function = loss_function, args = args)
    print("test time", time.time() - test_time)
    """
    # Collect the input tensor of every test batch; the labels are not used.
    data = [data_input for data_input, _ in test_loader]

    #inference(model = model, device = device, data = data[0], args = args)

    # Producer/consumer pair: image_loader feeds the queue, inf reads from it.
    queue_1 = Queue()
    item_1 = image_loader(data, queue_1)
    item_2 = inf(queue_1, model, device, args)
    # NOTE(review): ``args`` of Process is expected to be a tuple of
    # arguments; if image_loader/inf return generators, ``args=(item_1,)`` /
    # ``args=(item_2,)`` was probably intended -- confirm against them. The
    # processes are daemons and are never joined here, so the caller has to
    # keep the parent process alive while they run.
    p_1 = Process(target=next, args=item_1, daemon=True)
    p_2 = Process(target=next, args=item_2, daemon=True)
    p_2.start()
    p_1.start()
예제 #11
0
def create_optimizer(net,
                     name,
                     learning_rate,
                     weight_decay,
                     momentum=0,
                     fp16_loss_scale=None,
                     optimizer_state=None,
                     device=None):
    """Place ``net`` on a device and build the optimizer that trains it.

    The network is cast to float first and, when fp16 is requested, converted
    with apex ``network_to_half``. On a CUDA device it is additionally wrapped
    in ``torch.nn.DataParallel`` and cudnn benchmarking is switched on.

    :param net: the network to train.
    :param name: optimizer name: ``'sgd'``, ``'adamw'`` or ``'adam'``.
    :param learning_rate: learning rate passed to the optimizer.
    :param weight_decay: weight decay passed to the optimizer.
    :param momentum: momentum, used by ``'sgd'`` only.
    :param fp16_loss_scale: ``None`` disables fp16; ``0`` selects dynamic
        loss scaling; any other value is used as the static loss scale.
    :param optimizer_state: optional state dict to restore into the optimizer.
    :param device: device hint handed to ``choose_device``.
    :return: the tuple ``(net, optimizer)``; the optimizer always exposes
        ``backward(loss)``.
    :raises NotImplementedError: if ``name`` is not a supported optimizer.
    """
    net.float()

    half_precision = fp16_loss_scale is not None
    if half_precision:
        from apex import fp16_utils
        net = fp16_utils.network_to_half(net)

    device = choose_device(device)
    print('use', device)
    on_gpu = device.type == 'cuda'
    if on_gpu:
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    net = net.to(device)

    # Only parameters that require gradients are optimized.
    trainable = [param for param in net.parameters() if param.requires_grad]
    print('N of parameters', len(trainable))

    if name == 'adam':
        optimizer = optim.Adam(
            trainable, lr=learning_rate, weight_decay=weight_decay)
    elif name == 'adamw':
        # Imported lazily so the module is only needed when it is selected.
        from .adamw import AdamW
        optimizer = AdamW(
            trainable, lr=learning_rate, weight_decay=weight_decay)
    elif name == 'sgd':
        optimizer = optim.SGD(
            trainable,
            lr=learning_rate,
            momentum=momentum,
            weight_decay=weight_decay)
    else:
        raise NotImplementedError(name)

    if not half_precision:
        # Give the plain optimizer the same backward(loss) entry point that
        # FP16_Optimizer offers, so callers need not branch on precision.
        optimizer.backward = lambda loss: loss.backward()
    else:
        scale_kwargs = (dict(dynamic_loss_scale=True)
                        if fp16_loss_scale == 0 else
                        dict(static_loss_scale=fp16_loss_scale))
        print('FP16_Optimizer', scale_kwargs)
        optimizer = fp16_utils.FP16_Optimizer(optimizer, **scale_kwargs)

    if not optimizer_state:
        return net, optimizer

    if not half_precision:
        optimizer.load_state_dict(optimizer_state)
    elif 'optimizer_state_dict' in optimizer_state:
        # The state carries an 'optimizer_state_dict' entry; that entry is
        # what gets handed to the FP16_Optimizer wrapper.
        # NOTE(review): confirm the wrapper accepts this nested dict rather
        # than the full state dict it was taken from.
        optimizer.load_state_dict(optimizer_state['optimizer_state_dict'])
    else:
        # State of a plain optimizer: restore the wrapped optimizer only.
        optimizer.optimizer.load_state_dict(optimizer_state)

    return net, optimizer