Code example #1
def main():
    # os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))

        logger.info(args)

        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log: logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(), rank=env_rank())
        assert (dist_utils.env_world_size() == distributed.get_world_size())
        if write_log: logger.info("Distributed: success (%d/%d)" % (args.local_rank, distributed.get_world_size()))
        device = torch.device("cuda:%d" % torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cuda:%d" % torch.cuda.current_device() if torch.cuda.is_available() else "cpu")

    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = ['itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time', 'grad_norm']
    testcolumns = ['wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time', 'transport_cost']

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns).cuda()
    if args.distributed: model = dist_utils.DDP(model,
                                                device_ids=[args.local_rank],
                                                output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns, regularization_fns)

    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log: logger.info(model)
    if write_log: logger.info("Number of trainable parameters: {}".format(count_parameters(model)))
    if write_log: logger.info('Iters per train epoch: {}'.format(len(train_loader)))
    if write_log: logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9,
                              nesterov=False)

    # restore parameters
    if args.resume is not None:
        print('resume from checkpoint')
        checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization (fixed_z is also saved into the checkpoint below, so it must exist even when write_log is False).
    fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
        chkdir = args.save
    # Disabled branch for --resume together with --validate:
    # elif args.resume and args.validate:
    #     chkdir = os.path.dirname(args.resume)
    #     wall_clock = 0
    #     itr = 0
    #     best_loss = 0.0
    #     begin_epoch = 0
    else:
        chkdir = os.path.dirname(args.resume)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] + 1)  # approximate: resume from the epoch after the last logged test epoch

    if args.distributed:
        if write_log: logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    # This script runs a single pass: save a checkpoint, then dump real and generated images.
    for epoch in range(begin_epoch, begin_epoch + 1):
        # compute test loss
        print('Evaluating')
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            if hasattr(model, 'module'):
                _state = model.module.state_dict()
            else:
                _state = model.state_dict()
            torch.save({
                "args": args,
                "state_dict": _state,  # model.module.state_dict() if torch.cuda.is_available() else model.state_dict(),
                "optim_state_dict": optimizer.state_dict(),
                "fixed_z": fixed_z.cpu()
            }, os.path.join(args.save, "checkpt_%d.pth" % epoch))

        # save real images, then generate with different temperatures
        fig_num = 64
        if True:  # args.save_real:
            # advance to roughly the 100th test batch and use it as the grid of real images
            for i, (x, y) in enumerate(test_loader):
                if i > 100:
                    break
            if x.shape[0] > fig_num:
                x = x[:fig_num, ...]
            fig_filename = os.path.join(chkdir, "real.jpg")
            save_image(x.float() / 255.0, fig_filename, nrow=8)

        if True:  # args.generate:
            print('\nGenerating images... ')
            fixed_z = cvt(torch.randn(fig_num, *data_shape))
            nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
            for t in [1.0, 0.99, 0.98, 0.97, 0.96, 0.95, 0.93, 0.92, 0.90,
                      0.85, 0.8, 0.75, 0.7, 0.65, 0.6]:
                # visualize samples and density
                fig_filename = os.path.join(chkdir, "generated-T%g.jpg" % t)
                utils.makedirs(os.path.dirname(fig_filename))
                generated_samples = model(t * fixed_z, reverse=True)
                x = unshift(generated_samples[0].view(-1, *data_shape), 8)
                save_image(x, fig_filename, nrow=nb)
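The loop above scales a fixed latent batch by a temperature t before the reverse pass: t < 1 draws z from a narrower Gaussian, trading sample diversity for fidelity. Below is a minimal self-contained sketch of the same idea; it assumes only the flow-style model(z, reverse=True) interface used in the listing, and the function name sample_at_temperatures is made up for illustration.

import torch
from torchvision.utils import save_image

@torch.no_grad()
def sample_at_temperatures(model, data_shape, temps=(1.0, 0.9, 0.8), n=64):
    # One shared latent batch, so only the temperature differs between grids.
    z = torch.randn(n, *data_shape)
    for t in temps:
        # Scaling z by t < 1 samples from a narrower Gaussian; images look
        # smoother and more typical at the cost of diversity.
        x = model(t * z, reverse=True)[0].view(-1, *data_shape)
        save_image(x, "generated-T%g.jpg" % t, nrow=8)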
Code example #2
File: train.py  Project: kinaanaamir/ffjord-rnode
def main():
    #os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                                  filepath=os.path.abspath(__file__))

        logger.info(args)

        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log: logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend,
                                       init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(),
                                       rank=env_rank())
        assert (dist_utils.env_world_size() == distributed.get_world_size())
        if write_log:
            logger.info("Distributed: success (%d/%d)" %
                        (args.local_rank, distributed.get_world_size()))

    # get device (this variant runs on CPU; the CUDA line is kept for reference)
    # device = torch.device("cuda:%d" % torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
    device = "cpu"
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = [
        'itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time',
        'grad_norm'
    ]
    testcolumns = [
        'wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time',
        'transport_cost'
    ]

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns)
    # model = model.cuda()
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns,
                                                     regularization_fns)

    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log: logger.info(model)
    if write_log:
        logger.info("Number of trainable parameters: {}".format(
            count_parameters(model)))
    if write_log:
        logger.info('Iters per train epoch: {}'.format(len(train_loader)))
    if write_log: logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay,
                              momentum=0.9,
                              nesterov=False)

    # restore parameters
    if args.resume is not None:
        checkpt = torch.load(
            args.resume,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization (fixed_z is also saved into the checkpoint below, so it must exist even when write_log is False).
    fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
    else:
        chkdir = os.path.dirname(args.resume)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] + 1)  # approximate: resume from the epoch after the last logged test epoch

    if args.distributed:
        if write_log: logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    for epoch in range(begin_epoch, args.num_epochs + 1):
        if not args.validate:
            model.train()

            with open(trainlog, 'a') as f:
                if write_log: csvlogger = csv.DictWriter(f, traincolumns)

                for _, (x, y) in enumerate(train_loader):
                    start = time.time()
                    update_lr(optimizer, itr)
                    optimizer.zero_grad()

                    # cast data and move to device
                    x = add_noise(cvt(x), nbits=args.nbits)
                    #x = x.clamp_(min=0, max=1)
                    # compute loss
                    bpd, (x, z), reg_states = compute_bits_per_dim(x, model)
                    if np.isnan(bpd.data.item()):
                        raise ValueError('model returned nan during training')
                    elif np.isinf(bpd.data.item()):
                        raise ValueError('model returned inf during training')

                    loss = bpd
                    if regularization_coeffs:
                        reg_loss = sum(reg_state * coeff
                                       for reg_state, coeff in zip(
                                           reg_states, regularization_coeffs)
                                       if coeff != 0)
                        loss = loss + reg_loss
                    total_time = count_total_time(model)

                    loss.backward()
                    nfe_opt = count_nfe(model)
                    if write_log: steps_meter.update(nfe_opt)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.max_grad_norm)

                    optimizer.step()

                    itr_time = time.time() - start
                    wall_clock += itr_time

                    batch_size = x.size(0)
                    metrics = torch.tensor([
                        1., batch_size,
                        loss.item(),
                        bpd.item(), nfe_opt, grad_norm, *reg_states
                    ]).float()

                    # sum metrics across workers; the leading 1. counts participating GPUs
                    total_gpus, batch_total, r_loss, r_bpd, r_nfe, r_grad_norm, *rv = dist_utils.sum_tensor(
                        metrics).cpu().numpy()

                    if write_log:
                        time_meter.update(itr_time)
                        bpd_meter.update(r_bpd / total_gpus)
                        loss_meter.update(r_loss / total_gpus)
                        grad_meter.update(r_grad_norm / total_gpus)
                        tt_meter.update(total_time)

                        fmt = '{:.4f}'
                        logdict = {
                            'itr': itr,
                            'wall': fmt.format(wall_clock),
                            'itr_time': fmt.format(itr_time),
                            'loss': fmt.format(r_loss / total_gpus),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'total_time': fmt.format(total_time),
                            'fe': r_nfe / total_gpus,
                            'grad_norm': fmt.format(r_grad_norm / total_gpus),
                        }
                        if regularization_coeffs:
                            rv = tuple(v_ / total_gpus for v_ in rv)
                            logdict = append_regularization_csv_dict(
                                logdict, regularization_fns, rv)
                        csvlogger.writerow(logdict)

                        if itr % args.log_freq == 0:
                            log_message = (
                                "Itr {:06d} | Wall {:.3e}({:.2f}) | "
                                "Time/Itr {:.2f}({:.2f}) | BPD {:.2f}({:.2f}) | "
                                "Loss {:.2f}({:.2f}) | "
                                "FE {:.0f}({:.0f}) | Grad Norm {:.3e}({:.3e}) | "
                                "TT {:.2f}({:.2f})".format(
                                    itr, wall_clock, wall_clock / (itr + 1),
                                    time_meter.val, time_meter.avg,
                                    bpd_meter.val, bpd_meter.avg,
                                    loss_meter.val, loss_meter.avg,
                                    steps_meter.val, steps_meter.avg,
                                    grad_meter.val, grad_meter.avg,
                                    tt_meter.val, tt_meter.avg))
                            if regularization_coeffs:
                                log_message = append_regularization_to_log(
                                    log_message, regularization_fns, rv)
                            logger.info(log_message)

                    itr += 1

        # compute test loss
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            torch.save({
                "args": args,
                "state_dict": model.module.state_dict() if hasattr(model, 'module') else model.state_dict(),
                "optim_state_dict": optimizer.state_dict(),
                "fixed_z": fixed_z.cpu(),
            }, os.path.join(args.save, "checkpt.pth"))
        if epoch % args.val_freq == 0 or args.validate:
            with open(testlog, 'a') as f:
                if write_log: csvlogger = csv.DictWriter(f, testcolumns)
                with torch.no_grad():
                    start = time.time()
                    if write_log: logger.info("validating...")

                    lossmean = 0.
                    meandist = 0.
                    steps = 0
                    tt = 0.
                    for i, (x, y) in enumerate(test_loader):
                        x = shift(cvt(x), nbits=args.nbits)
                        loss, (x, z), _ = compute_bits_per_dim(x, model)
                        # incremental means over batches: new_mean = i/(i+1)*old_mean + value/(i+1)
                        dist = (x.view(x.size(0), -1) - z).pow(2).mean(dim=-1).mean()
                        meandist = i / (i + 1) * meandist + dist / (i + 1)
                        lossmean = i / (i + 1) * lossmean + loss / (i + 1)

                        tt = i / (i + 1) * tt + count_total_time(model) / (i + 1)
                        steps = i / (i + 1) * steps + count_nfe(model) / (i + 1)

                    loss = lossmean.item()
                    metrics = torch.tensor([1., loss, meandist, steps]).float()

                    total_gpus, r_bpd, r_mdist, r_steps = dist_utils.sum_tensor(
                        metrics).cpu().numpy()
                    eval_time = time.time() - start

                    if write_log:
                        fmt = '{:.4f}'
                        logdict = {
                            'epoch': epoch,
                            'eval_time': fmt.format(eval_time),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'wall': fmt.format(wall_clock),
                            'total_time': fmt.format(tt),
                            'transport_cost': fmt.format(r_mdist / total_gpus),
                            'fe': '{:.2f}'.format(r_steps / total_gpus)
                        }

                        csvlogger.writerow(logdict)

                        logger.info(
                            "Epoch {:04d} | Time {:.4f}, Bit/dim {:.4f}, Steps {:.4f}, TT {:.2f}, Transport Cost {:.2e}"
                            .format(epoch, eval_time, r_bpd / total_gpus,
                                    r_steps / total_gpus, tt,
                                    r_mdist / total_gpus))

                    loss = r_bpd / total_gpus

                    if loss < best_loss and args.local_rank == 0:
                        best_loss = loss
                        shutil.copyfile(os.path.join(args.save, "checkpt.pth"),
                                        os.path.join(args.save, "best.pth"))

            # visualize samples and density
            if write_log:
                with torch.no_grad():
                    fig_filename = os.path.join(args.save, "figs",
                                                "{:04d}.jpg".format(epoch))
                    utils.makedirs(os.path.dirname(fig_filename))
                    generated_samples, _, _ = model(fixed_z, reverse=True)
                    generated_samples = generated_samples.view(-1, *data_shape)
                    nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
                    save_image(unshift(generated_samples, nbits=args.nbits),
                               fig_filename,
                               nrow=nb)
            if args.validate:
                break
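Both the training and validation loops above aggregate scalar metrics by packing them into a single tensor, summing it across workers with dist_utils.sum_tensor, and dividing by the worker count recovered from the leading 1. entry. A minimal sketch of that pattern with plain torch.distributed follows; sum_tensor is presumably a thin wrapper over an all_reduce like this, and reduce_metrics is a hypothetical name.

import torch
import torch.distributed as dist

def reduce_metrics(loss, bpd, nfe):
    # A leading 1. makes the summed tensor also count participating workers.
    metrics = torch.tensor([1., loss, bpd, nfe], device='cuda')
    dist.all_reduce(metrics, op=dist.ReduceOp.SUM)  # elementwise sum over ranks
    n_gpus, loss_sum, bpd_sum, nfe_sum = metrics.cpu().tolist()
    return loss_sum / n_gpus, bpd_sum / n_gpus, nfe_sum / n_gpus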
Code example #3
def main():
    os.system('shutdown -c')  # cancel previous shutdown command
    log.console(args)
    tb.log('sizes/world', dist_utils.env_world_size())

    # need to index validation directory before we start counting the time
    dataloader.sort_ar(args.data + '/validation')

    if args.distributed:
        log.console('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=dist_utils.env_world_size())
        assert (dist_utils.env_world_size() == dist.get_world_size())
        log.console("Distributed: success (%d/%d)" %
                    (args.local_rank, dist.get_world_size()))

    log.console("Loading model")
    model = resnet.resnet50(bn0=args.init_bn0).cuda()
    if args.fp16: model = network_to_half(model)
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)
    best_top5 = 93  # only save models above 93% top-5; otherwise a checkpoint would be written every epoch

    global model_params, master_params
    if args.fp16: model_params, master_params = prep_param_lists(model)
    else: model_params = master_params = model.parameters()

    optim_params = experimental_utils.bnwd_optim_params(
        model, model_params, master_params) if args.no_bn_wd else master_params

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        optim_params,
        0,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )  # start with 0 lr. Scheduler will change this later

    if args.resume:
        checkpoint = torch.load(
            args.resume,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch']
        best_top5 = checkpoint['best_top5']
        optimizer.load_state_dict(checkpoint['optimizer'])

    # save script so we can reproduce from logs
    shutil.copy2(os.path.realpath(__file__), f'{args.logdir}')

    log.console(
        "Creating data loaders (this could take up to 10 minutes if volume needs to be warmed up)"
    )
    phases = eval(args.phases)  # args.phases is a Python-literal string describing the schedule
    dm = DataManager([copy.deepcopy(p) for p in phases if 'bs' in p])
    scheduler = Scheduler(optimizer,
                          [copy.deepcopy(p) for p in phases if 'lr' in p])

    start_time = datetime.now()  # start timing only after everything is loaded
    if args.evaluate:
        return validate(dm.val_dl, model, criterion, 0, start_time)

    if args.distributed:
        log.console('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    log.event("~~epoch\thours\ttop1\ttop5\n")
    for epoch in range(args.start_epoch, scheduler.tot_epochs):
        dm.set_epoch(epoch)

        train(dm.trn_dl, model, criterion, optimizer, scheduler, epoch)
        top1, top5 = validate(dm.val_dl, model, criterion, epoch, start_time)

        time_diff = (datetime.now() - start_time).total_seconds() / 3600.0
        log.event(f'~~{epoch}\t{time_diff:.5f}\t\t{top1:.3f}\t\t{top5:.3f}\n')

        is_best = top5 > best_top5
        best_top5 = max(top5, best_top5)
        if args.local_rank == 0:
            if is_best:
                save_checkpoint(epoch,
                                model,
                                best_top5,
                                optimizer,
                                is_best=True,
                                filename='model_best.pth.tar')
            phase = dm.get_phase(epoch)
            if phase:
                save_checkpoint(
                    epoch,
                    model,
                    best_top5,
                    optimizer,
                    filename=f'sz{phase["bs"]}_checkpoint.path.tar')
Code example #4
def main():
    # os.system('sudo shutdown -c')  # cancel previous shutdown command
    log.console(args)
    tb.log('sizes/world', dist_utils.env_world_size())

    print(args.data)
    assert os.path.exists(args.data)

    # need to index validation directory before we start counting the time
    dataloader.sort_ar(args.data + '/val')

    if args.distributed:
        log.console('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=dist_utils.env_world_size())
        assert (dist_utils.env_world_size() == dist.get_world_size())
        # todo(y): use global_rank instead of local_rank here
        log.console("Distributed: success (%d/%d)" %
                    (args.local_rank, dist.get_world_size()))

    log.console("Loading model")
    #from mobilenetv3 import MobileNetV3
    #model = MobileNetV3(mode='small', num_classes=1000).cuda()
    if args.network == 'resnet50':
        model.resnet.resnet50(bn0=args.init_bn0).cuda()
    elif args.network == 'resnet50friendlyv1':
        model = resnet.resnet50friendly(bn0=args.init_bn0, hybrid=True).cuda()
    elif args.network == 'resnet50friendlyv2':
        model = resnet.resnet50friendly2(bn0=args.init_bn0, hybrid=True).cuda()
    elif args.network == 'resnet50friendlyv3':
        model = resnet.resnet50friendly3(bn0=args.init_bn0, hybrid=True).cuda()
    elif args.network == 'resnet50friendlyv4':
        model = resnet.resnet50friendly4(bn0=args.init_bn0, hybrid=True).cuda()
    #import resnet_friendly
    #model = resnet_friendly.ResNet50Friendly().cuda()
    #model = torchvision.models.mobilenet_v2(pretrained=False).cuda()
    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)
    best_top5 = 93  # only save models above 93% top-5; otherwise a checkpoint would be written every epoch

    global model_params, master_params
    if args.fp16:
        model_params, master_params = prep_param_lists(model)
    else:
        model_params = master_params = model.parameters()

    optim_params = experimental_utils.bnwd_optim_params(
        model, model_params, master_params) if args.no_bn_wd else master_params

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        optim_params,
        0,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )  # start with 0 lr. Scheduler will change this later

    if args.resume:
        checkpoint = torch.load(
            args.resume,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch']
        current_phase = checkpoint['current_phase']
        best_top5 = checkpoint['best_top5']
        optimizer.load_state_dict(checkpoint['optimizer'])

    # save script so we can reproduce from logs
    shutil.copy2(os.path.realpath(__file__), f'{args.logdir}')

    log.console(
        "Creating data loaders (this could take up to 10 minutes if volume needs to be warmed up)"
    )
    # phases = util.text_unpickle(args.phases)
    # lr is scaled with batch size, relative to the bs=512 baseline phase
    lr = 0.9
    scale_224 = 224 / 512  # bs 224 phases
    scale_288 = 128 / 512  # bs 128 phase (sz 288)
    one_machine = [
        {'ep': 0, 'sz': 128, 'bs': 512, 'trndir': ''},
        {'ep': (0, 5), 'lr': (lr, lr * 2)},  # lr warmup is better with --init-bn0
        {'ep': 5, 'lr': lr},
        {'ep': 14, 'sz': 224, 'bs': 224, 'lr': lr * scale_224},
        {'ep': 16, 'lr': lr / 10 * scale_224},
        {'ep': 32, 'lr': lr / 100 * scale_224},
        {'ep': 37, 'lr': lr / 100 * scale_224},
        {'ep': 39, 'sz': 288, 'bs': 128, 'min_scale': 0.5, 'rect_val': True, 'lr': lr / 100 * scale_288},
        {'ep': (40, 44), 'lr': lr / 1000 * scale_288},
        # {'ep': (36, 40), 'lr': lr / 1000 * scale_288},
        {'ep': (45, 48), 'lr': lr / 10000 * scale_288},
        {'ep': (49, 52), 'sz': 288, 'bs': 224, 'lr': lr / 10000 * scale_224},
        # {'ep': (46, 50), 'sz': 320, 'bs': 64, 'lr': lr / 10000 * scale_320},
    ]
    # text_pickle followed by text_unpickle is a round-trip, so phases == one_machine
    phases = util.text_unpickle(util.text_pickle(one_machine))
    dm = DataManager([copy.deepcopy(p) for p in phases if 'bs' in p])
    scheduler = Scheduler(optimizer,
                          [copy.deepcopy(p) for p in phases if 'lr' in p])

    start_time = datetime.now()  # Loading start to after everything is loaded
    if args.evaluate:
        return validate(dm.val_dl, model, criterion, 0, start_time)

    if args.distributed:
        log.console('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    log.event("~~epoch\thours\ttop1\ttop5\n")
    for epoch in range(args.start_epoch, scheduler.tot_epochs):
        print(" The start epoch:", args.start_epoch)
        dm.set_epoch(epoch)

        train(dm.trn_dl, model, criterion, optimizer, scheduler, epoch)
        top1, top5 = validate(dm.val_dl, model, criterion, epoch, start_time)

        time_diff = (datetime.now() - start_time).total_seconds() / 3600.0
        log.event(f'~~{epoch}\t{time_diff:.5f}\t\t{top1:.3f}\t\t{top5:.3f}\n')

        is_best = top5 > best_top5
        best_top5 = max(top5, best_top5)
        phase_save = dm.get_phase(epoch)
        if args.local_rank == 0:
            if is_best:
                save_checkpoint(phase_save,
                                epoch,
                                model,
                                best_top5,
                                optimizer,
                                is_best=True,
                                filename='model_best_' + args.network +
                                args.name + '.pth.tar')
            else:
                save_checkpoint(phase_save,
                                epoch,
                                model,
                                top5,
                                optimizer,
                                is_best=False,
                                filename='model_epoch_latest_' + args.network +
                                args.name + '.pth.tar')
            if phase_save:
                save_checkpoint(
                    phase_save,
                    epoch,
                    model,
                    best_top5,
                    optimizer,
                    filename=f'sz{phase_save["bs"]}_checkpoint.pth.tar')