Example #1
    def __init__(self):
        self.tool = Tool(hparams.sens_num, hparams.key_len, hparams.sen_len,
                         hparams.poem_len, 0.0)
        self.tool.load_dic(hparams.vocab_path, hparams.ivocab_path)
        vocab_size = self.tool.get_vocab_size()
        print("vocabulary size: %d" % (vocab_size))
        PAD_ID = self.tool.get_PAD_ID()
        B_ID = self.tool.get_B_ID()
        assert vocab_size > 0 and PAD_ID >= 0 and B_ID >= 0
        self.hps = hparams._replace(vocab_size=vocab_size,
                                    pad_idx=PAD_ID,
                                    bos_idx=B_ID)

        # load model
        model = MixPoetAUS(self.hps)

        # load trained model
        utils.restore_checkpoint(self.hps.model_dir, device, model)
        self.model = model.to(device)
        self.model.eval()

        #utils.print_parameter_list(self.model)

        # load poetry filter
        print("loading poetry filter...")
        self.filter = PoetryFilter(self.tool.get_vocab(),
                                   self.tool.get_ivocab(), self.hps.data_dir)
        print("--------------------------")
Example #2
def main(device=torch.device('cuda:0')):
    # CLI arguments
    parser = arg.ArgumentParser(
        description='We all know what we are doing. Fighting!')
    parser.add_argument("--datasize",
                        "-d",
                        default="small",
                        type=str,
                        help="data size you want to use, small, medium, total")
    # Parsing
    args = parser.parse_args()
    # Data loaders
    datasize = args.datasize
    pathname = "data/nyu.zip"
    tr_loader, va_loader, te_loader = getTrainingValidationTestingData(
        datasize, pathname, batch_size=config("unet.batch_size"))

    # Model
    model = Net()

    # define loss function
    # criterion = torch.nn.L1Loss()

    # Attempts to restore the latest checkpoint if exists
    print("Loading unet...")
    model, start_epoch, stats = utils.restore_checkpoint(
        model, utils.config("unet.checkpoint"))
    acc, loss = utils.evaluate_model(model, te_loader, device)
    # axes = util.make_training_plot()
    print(f'Test Accuracy: {acc}')
    print(f'Test Loss: {loss}')
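Note: the snippets in this collection target several different restore_checkpoint helpers. Examples #2-#4 call utils.restore_checkpoint(model, checkpoint_path) and expect (model, start_epoch, stats) back. The following is only a minimal sketch of a compatible helper, assuming checkpoints are *.pth.tar dicts with 'epoch', 'state_dict' and 'stats' keys; these names are guesses, not the authors' actual utils module.

import os
import torch

def restore_checkpoint(model, checkpoint_dir):
    # Hypothetical sketch: load the newest *.pth.tar checkpoint if any.
    # Key names ('epoch', 'state_dict', 'stats') are assumptions.
    if not os.path.isdir(checkpoint_dir):
        return model, 0, {}
    files = [os.path.join(checkpoint_dir, f)
             for f in os.listdir(checkpoint_dir) if f.endswith('.pth.tar')]
    if not files:
        return model, 0, {}
    ckpt = torch.load(max(files, key=os.path.getmtime), map_location='cpu')
    model.load_state_dict(ckpt['state_dict'])
    return model, ckpt['epoch'], ckpt.get('stats', {})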
Example #3
def main(device=torch.device('cuda:0')):
    # CLI arguments
    parser = arg.ArgumentParser(
        description='We all know what we are doing. Fighting!')
    parser.add_argument("--datasize",
                        "-d",
                        default="small",
                        type=str,
                        help="data size you want to use, small, medium, total")
    # Parsing
    args = parser.parse_args()
    # Data loaders
    datasize = args.datasize
    pathname = "data/nyu.zip"
    tr_loader, va_loader, te_loader = getTrainingValidationTestingData(
        datasize, pathname, batch_size=config("unet.batch_size"))

    # Model
    #model = Net()
    #model = Dense121()
    model = Dense169()
    model = model.to(device)

    # define loss function
    # criterion = torch.nn.L1Loss()

    # Attempts to restore the latest checkpoint if exists
    print("Loading unet...")
    model, start_epoch, stats = utils.restore_checkpoint(
        model, utils.config("unet.checkpoint"))
    acc, loss = utils.evaluate_model(model, te_loader, device)
    # axes = util.make_training_plot()
    print(f'Test Error: {acc}')
    print(f'Test Loss: {loss}')

    # Get Test Images
    img_list = glob("examples/" + "*.png")

    # Set model to eval mode
    model.eval()
    model = model.to(device)

    # Begin testing loop
    print("Begin Test Loop ...")

    for idx, img_name in enumerate(img_list):

        img = load_images([img_name])
        img = torch.Tensor(img).float().to(device)
        print("Processing {}, Tensor Shape: {}".format(img_name, img.shape))

        with torch.no_grad():
            preds = model(img).squeeze(0)

        output = colorize(preds.data)
        output = output.transpose((1, 2, 0))
        cv2.imwrite(img_name.split(".")[0] + "_result.png", output)

        print("Processing {} done.".format(img_name))
Example #4
def main(device=torch.device('cuda:0')):
    # CLI arguments
    parser = arg.ArgumentParser(
        description='We all know what we are doing. Fighting!')
    parser.add_argument("--datasize",
                        "-d",
                        default="small",
                        type=str,
                        help="data size you want to use, small, medium, total")
    # Parsing
    args = parser.parse_args()
    # Data loaders

    # TODO:
    ####### Enter the model selection here! #####
    modelSelection = input(
        'Please input the type of model to be used(res50,dense121,dense169,mob_v2,mob):'
    )

    datasize = args.datasize
    filename = "nyu_new.zip"
    pathname = f"data/{filename}"
    csv = "data/nyu_csv.zip"
    te_loader = getTestingData(datasize,
                               csv,
                               pathname,
                               batch_size=config(modelSelection +
                                                 ".batch_size"))

    # Model
    if modelSelection.lower() == 'res50':
        model = Res50()
    elif modelSelection.lower() == 'dense121':
        model = Dense121()
    elif modelSelection.lower() == 'mob_v2':
        model = Mob_v2()
    elif modelSelection.lower() == 'dense169':
        model = Dense169()
    elif modelSelection.lower() == 'mob':
        model = Net()
    elif modelSelection.lower() == 'squeeze':
        model = Squeeze()
    else:
        assert False, 'Wrong type of model selection string!'
    model = model.to(device)

    # define loss function
    # criterion = torch.nn.L1Loss()

    # Attempts to restore the latest checkpoint if exists
    print(f"Loading {mdoelSelection}...")
    model, start_epoch, stats = utils.restore_checkpoint(
        model, utils.config(modelSelection + ".checkpoint"))
    acc, loss = utils.evaluate_model(model, te_loader, device, test=True)
    # axes = util.make_training_plot()
    print(f'Test Error: {acc}')
    print(f'Test Loss: {loss}')
Example #5
def train(mixpoet, tool, hps):
    last_epoch = utils.restore_checkpoint(hps.model_dir, device, mixpoet)

    if last_epoch is not None:
        print("checkpoint exists! directly recover!")
    else:
        print("checkpoint does not exist! train from scratch!")

    mix_trainer = MixTrainer(hps)
    mix_trainer.train(mixpoet, tool)
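Examples #1, #5, #7 and #8 use a second calling convention, utils.restore_checkpoint(model_dir, device, model): weights are loaded in place and the return value is the last epoch, or None when no checkpoint is found. A compatible sketch, with the file pattern and checkpoint keys as assumptions:

import glob
import os
import torch

def restore_checkpoint(model_dir, device, model):
    # Hypothetical in-place variant: returns the last epoch or None.
    ckpts = sorted(glob.glob(os.path.join(model_dir, '*.pt')))
    if not ckpts:
        return None
    ckpt = torch.load(ckpts[-1], map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    return ckpt.get('epoch')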
Example #6
def main(device=torch.device('cuda:0')):
    # Model
    modelSelection = input(
        'Please input the type of model to be used(res50,dense121,dense169,dense161,mob_v2,mob):'
    )
    if modelSelection.lower() == 'res50':
        model = Res50()
    elif modelSelection.lower() == 'dense121':
        model = Dense121()
    elif modelSelection.lower() == 'dense161':
        model = Dense161()
    elif modelSelection.lower() == 'mob_v2':
        model = Mob_v2()
    elif modelSelection.lower() == 'dense169':
        model = Dense169()
    elif modelSelection.lower() == 'mob':
        model = Net()
    elif modelSelection.lower() == 'squeeze':
        model = Squeeze()
    else:
        assert False, 'Wrong type of model selection string!'
    model = model.to(device)

    # Attempts to restore the latest checkpoint if exists
    print("Loading unet...")
    model, start_epoch, stats = utils.restore_checkpoint(
        model, utils.config(modelSelection + ".checkpoint"))

    # Get Test Images
    img_list = glob("examples/" + "*.png")

    # Set model to eval mode
    model.eval()
    model = model.to(device)

    # Begin testing loop
    print("Begin Test Loop ...")

    for idx, img_name in enumerate(img_list):

        img = load_images([img_name])
        img = torch.Tensor(img).float().to(device)
        print("Processing {}, Tensor Shape: {}".format(img_name, img.shape))

        with torch.no_grad():
            preds = model(img).squeeze(0)

        output = colorize(preds.data)
        output = output.transpose((1, 2, 0))
        cv2.imwrite(
            img_name.split(".")[0] + "_" + modelSelection + "_result.png",
            output)

        print("Processing {} done.".format(img_name))
Example #7
def train(wm_model, tool, hps, specified_device):
    last_epoch = utils.restore_checkpoint(
        hps.model_dir, specified_device, wm_model)

    if last_epoch is not None:
        print("checkpoint exists! directly recover!")
    else:
        print("checkpoint does not exist! train from scratch!")

    wm_trainer = WMTrainer(hps, specified_device)
    wm_trainer.train(wm_model, tool)
Example #8
    def __init__(self, hps, device):
        self.tool = Tool(hps.sens_num, hps.sen_len, hps.key_len,
                         hps.topic_slots, 0.0)
        self.tool.load_dic(hps.vocab_path, hps.ivocab_path)
        vocab_size = self.tool.get_vocab_size()
        print("vocabulary size: %d" % (vocab_size))
        PAD_ID = self.tool.get_PAD_ID()
        B_ID = self.tool.get_B_ID()
        assert vocab_size > 0 and PAD_ID >= 0 and B_ID >= 0
        self.hps = hps._replace(vocab_size=vocab_size,
                                pad_idx=PAD_ID,
                                bos_idx=B_ID)
        self.device = device

        # load model
        model = WorkingMemoryModel(self.hps, device)

        # load trained model
        utils.restore_checkpoint(self.hps.model_dir, device, model)
        self.model = model.to(device)
        self.model.eval()

        null_idxes = self.tool.load_function_tokens(self.hps.data_dir +
                                                    "fchars.txt").to(
                                                        self.device)
        self.model.set_null_idxes(null_idxes)

        self.model.set_tau(hps.min_tau)

        # load poetry filter
        print("loading poetry filter...")
        self.filter = PoetryFilter(self.tool.get_vocab(),
                                   self.tool.get_ivocab(), self.hps.data_dir)

        self.visual_tool = Visualization(hps.topic_slots, hps.his_mem_slots,
                                         "../log/")
        print("--------------------------")
Example #9
def evaluate_checkpoint_on(restore_checkpoint,
                           dataset_cfg,
                           _run,
                           model_update_cfg=None):
    # Avoid a mutable default argument; fall back to an empty config.
    model_cfg, _, epoch = utils.restore_checkpoint(
        restore_checkpoint,
        model_cfg=model_update_cfg or {},
        map_location='cpu')
    #model_cfg['backbone']['output_dim'] = 256
    dataloaders = dataloader_builder.build(dataset_cfg)
    model = model_builder.build(model_cfg)
    # TODO needs to be from dataset
    if 'seg_class_mapping' in model_cfg:
        mapping = model_cfg['seg_class_mapping']
    else:
        mapping = None

    model.seg_mapping = mapping

    model = torch.nn.DataParallel(model, device_ids=_run.config['device_id'])
    model = model.cuda()
    return evaluate(dataloaders, model, epoch, keep=True)
Example #10
optimizer = torch.optim.Adam([{
    'params': net.encoder.parameters(),
    'weight_decay': 1e-2
}, {
    'params': net.decoder.parameters(),
    'weight_decay': 0
}],
                             lr=1e-4,
                             eps=1e-6)

scheduler = StepLR(optimizer, step_size=1, gamma=0.8)
test_loss = float('inf')
epoch_loss = float('inf')

if restore_check:
    net, optimizer, scheduler, epoch_loss, test_loss = restore_checkpoint(
        net, optimizer, scheduler, masked, recent=True, inception=False)

# every param group shares the same lr, so read it from the last one
lr = optimizer.param_groups[-1]['lr']

trainNet(net,
         batch_size=8,
         n_epochs=100,
         learning_rate=lr,
         last_epoch_loss=epoch_loss,
         last_loss=test_loss,
         optimizer=optimizer,
         scheduler=scheduler,
         save=True)
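Example #10 hands Adam two parameter groups so the encoder is regularized with weight decay while the decoder is not; lr and eps are shared across groups. A self-contained illustration of the same pattern (the two-layer net is a toy stand-in, not the example's model):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))  # toy stand-in
optimizer = torch.optim.Adam(
    [{'params': net[0].parameters(), 'weight_decay': 1e-2},  # decayed
     {'params': net[1].parameters(), 'weight_decay': 0}],    # not decayed
    lr=1e-4, eps=1e-6)
lr = optimizer.param_groups[-1]['lr']  # what the loop above ends up with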
Example #11
def main(device, tr_loader, va_loader, te_loader, modelSelection):
    """Train CNN and show training plots."""
    # CLI arguments
    # parser = arg.ArgumentParser(description='We all know what we are doing. Fighting!')
    # parser.add_argument("--datasize", "-d", default="small", type=str,
    #                     help="data size you want to use, small, medium, total")
    # Parsing
    # args = parser.parse_args()
    # Data loaders
    # datasize = args.datasize
    # Model
    if modelSelection.lower() == 'res50':
        model = Res50()
    elif modelSelection.lower() == 'dense121':
        model = Dense121()
    elif modelSelection.lower() == 'mobv2':
        model = Mob_v2()
    elif modelSelection.lower() == 'dense169':
        model = Dense169()
    elif modelSelection.lower() == 'mob':
        model = Net()
    elif modelSelection.lower() == 'squeeze':
        model = Squeeze()
    else:
        assert False, 'Wrong type of model selection string!'
    # Model
    # model = Net()
    # model = Squeeze()
    model = model.to(device)

    # TODO: define loss function, and optimizer
    learning_rate = utils.config(modelSelection + ".learning_rate")
    criterion = DepthLoss(0.1).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    number_of_epoches = 10
    #

    # Attempts to restore the latest checkpoint if exists
    print("Loading unet...")
    model, start_epoch, stats = utils.restore_checkpoint(
        model, utils.config(modelSelection + ".checkpoint"))

    running_va_loss = [] if 'va_loss' not in stats else stats['va_loss']
    running_va_acc = [] if 'va_err' not in stats else stats['va_err']
    running_tr_loss = [] if 'tr_loss' not in stats else stats['tr_loss']
    running_tr_acc = [] if 'tr_err' not in stats else stats['tr_err']
    tr_acc, tr_loss = utils.evaluate_model(model, tr_loader, device)
    acc, loss = utils.evaluate_model(model, va_loader, device)
    running_va_acc.append(acc)
    running_va_loss.append(loss)
    running_tr_acc.append(tr_acc)
    running_tr_loss.append(tr_loss)
    stats = {
        'va_err': running_va_acc,
        'va_loss': running_va_loss,
        'tr_err': running_tr_acc,
        'tr_loss': running_tr_loss,
        # 'num_of_epoch': 0
    }
    # Loop over the entire dataset multiple times
    # for epoch in range(start_epoch, config('cnn.num_epochs')):
    epoch = start_epoch
    # while curr_patience < patience:
    while epoch < number_of_epoches:
        # Train model
        utils.train_epoch(device, tr_loader, model, criterion, optimizer)
        # Save checkpoint
        utils.save_checkpoint(model, epoch + 1,
                              utils.config(modelSelection + ".checkpoint"),
                              stats)
        # Evaluate model
        tr_acc, tr_loss = utils.evaluate_model(model, tr_loader, device)
        va_acc, va_loss = utils.evaluate_model(model, va_loader, device)
        running_va_acc.append(va_acc)
        running_va_loss.append(va_loss)
        running_tr_acc.append(tr_acc)
        running_tr_loss.append(tr_loss)
        epoch += 1
    print("Finished Training")
    utils.make_plot(running_tr_loss, running_tr_acc, running_va_loss,
                    running_va_acc)
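Examples #11 and #18 pair the restore helper with utils.save_checkpoint(model, epoch + 1, checkpoint_path, stats). A saver matching the restore sketch after Example #2 might look like this; the file naming is a guess:

import os
import torch

def save_checkpoint(model, epoch, checkpoint_dir, stats):
    # Hypothetical counterpart to the restore sketch above.
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save({'epoch': epoch,
                'state_dict': model.state_dict(),
                'stats': stats},
               os.path.join(checkpoint_dir, f'epoch={epoch}.checkpoint.pth.tar'))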
Example #12
def play(args):
    env = create_mario_env(args.env_name, ACTIONS[args.move_set])

    observation_space = env.observation_space.shape[0]
    action_space = env.action_space.n

    model = ActorCritic(observation_space, action_space)

    checkpoint_file = \
        f"{args.env_name}/{args.model_id}_{args.algorithm}_params.tar"
    checkpoint = restore_checkpoint(checkpoint_file)
    assert args.env_name == checkpoint['env'], \
        f"This checkpoint is for a different environment: {checkpoint['env']}"
    args.model_id = checkpoint['id']

    print(f"Environment: {args.env_name}")
    print(f"      Agent: {args.model_id}")
    model.load_state_dict(checkpoint['model_state_dict'])

    state = env.reset()
    state = torch.from_numpy(state)
    reward_sum = 0
    done = True
    episode_length = 0
    start_time = time.time()
    for step in count():
        episode_length += 1

        # shared model sync
        if done:
            cx = torch.zeros(1, 512)
            hx = torch.zeros(1, 512)

        else:
            cx = cx.data
            hx = hx.data

        with torch.no_grad():
            value, logit, (hx, cx) = model((state.unsqueeze(0), (hx, cx)))

        prob = F.softmax(logit, dim=-1)
        action = prob.max(-1, keepdim=True)[1]

        action_idx = action.item()
        action_out = ACTIONS[args.move_set][action_idx]
        state, reward, done, info = env.step(action_idx)
        reward_sum += reward

        print(
            f"{emojize(':mushroom:')} World {info['world']}-{info['stage']} | {emojize(':video_game:')}: [ {' + '.join(action_out):^13s} ] | ",
            end='\r',
        )

        env.render()

        if done:
            t = time.time() - start_time

            print(
                f"{emojize(':mushroom:')} World {info['world']}-{info['stage']} |" + \
                f" {emojize(':video_game:')}: [ {' + '.join(action_out):^13s} ] | " + \
                f"ID: {args.model_id}, " + \
                f"Time: {time.strftime('%H:%M:%S', time.gmtime(t)):^9s}, " + \
                f"Reward: {reward_sum: 10.2f}, " + \
                f"Progress: {(info['x_pos'] / 3225) * 100: 3.2f}%",
                end='\r',
                flush=True,
            )

            reward_sum = 0
            episode_length = 0
            time.sleep(args.reset_delay)
            state = env.reset()

        state = torch.from_numpy(state)
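Example #12 rolls the policy out greedily: prob.max(-1, keepdim=True)[1] picks the most probable action instead of sampling, which is equivalent to an argmax over the logits. A minimal check of that equivalence:

import torch
import torch.nn.functional as F

logit = torch.tensor([[1.0, 3.0, 0.5]])      # fake logits for 3 actions
prob = F.softmax(logit, dim=-1)
action = prob.max(-1, keepdim=True)[1]       # greedy action, shape (1, 1)
assert action.item() == torch.argmax(logit, dim=-1).item()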
Example #13
def main(args):
    print(f" Session ID: {args.uuid}")

    # logging
    log_dir = f'logs/{args.env_name}/{args.model_id}/{args.uuid}/'
    args_logger = setup_logger('args', log_dir, 'args.log')
    env_logger = setup_logger('env', log_dir, 'env.log')

    if args.debug:
        debug.packages()
    os.environ['OMP_NUM_THREADS'] = "1"
    if torch.cuda.is_available():
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        devices = ",".join([str(i) for i in range(torch.cuda.device_count())])
        os.environ["CUDA_VISIBLE_DEVICES"] = devices

    args_logger.info(vars(args))
    env_logger.info(dict(os.environ))  # vars() would log _Environ internals, not the variables

    env = create_atari_environment(args.env_name)

    shared_model = ActorCritic(env.observation_space.shape[0],
                               env.action_space.n)

    if torch.cuda.is_available():
        shared_model = shared_model.cuda()

    shared_model.share_memory()

    optimizer = SharedAdam(shared_model.parameters(), lr=args.lr)
    optimizer.share_memory()

    if args.load_model:  # TODO Load model before initializing optimizer
        checkpoint_file = f"{args.env_name}/{args.model_id}_{args.algorithm}_params.tar"
        checkpoint = restore_checkpoint(checkpoint_file)
        assert args.env_name == checkpoint['env'], \
            "Checkpoint is for different environment"
        args.model_id = checkpoint['id']
        args.start_step = checkpoint['step']
        print("Loading model from checkpoint...")
        print(f"Environment: {args.env_name}")
        print(f"      Agent: {args.model_id}")
        print(f"      Start: Step {args.start_step}")
        shared_model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    else:
        print(f"Environment: {args.env_name}")
        print(f"      Agent: {args.model_id}")

    torch.manual_seed(args.seed)

    print(
        FontColor.BLUE + \
        f"CPUs:    {mp.cpu_count(): 3d} | " + \
        f"GPUs: {None if not torch.cuda.is_available() else torch.cuda.device_count()}" + \
        FontColor.END
    )

    processes = []

    counter = mp.Value('i', 0)
    lock = mp.Lock()

    # Queue training processes
    num_processes = args.num_processes
    no_sample = args.non_sample  # count of non-sampling processes

    if args.num_processes > 1:
        num_processes = args.num_processes - 1

    samplers = num_processes - no_sample

    for rank in range(0, num_processes):
        device = 'cpu'
        if torch.cuda.is_available():
            device = 0  # TODO: Need to move to distributed to handle multigpu
        if rank < samplers:  # random action
            p = mp.Process(
                target=train,
                args=(rank, args, shared_model, counter, lock, optimizer,
                      device),
            )
        else:  # best action
            p = mp.Process(
                target=train,
                args=(rank, args, shared_model, counter, lock, optimizer,
                      device, False),
            )
        p.start()
        time.sleep(1.)
        processes.append(p)

    # Queue test process
    p = mp.Process(target=test,
                   args=(args.num_processes, args, shared_model, counter, 0))

    p.start()
    processes.append(p)

    for p in processes:
        p.join()
Example #14
        for i, data in enumerate(dataset, 0):
            inputs, labels, mask = data
            inputs, labels, mask = inputs.to(device), labels.to(
                device), mask.to(device)
            outputs = net(inputs.float())
            loss_size = loss(outputs, labels, mask)
            total_loss[index] = total_loss[index] + loss_size.data

            if debug:
                if (i + 1) % (len(dataset) // 3 + 1) == 0:
                    print("{:d}%".format(int(i / len(dataset) * 100)))
                if (i + 1) % (len(dataset) // 10 + 1) == 0:
                    draw_debug(inputs, labels, mask, outputs, name='test')

        total_loss[index] = total_loss[index] / len(dataset)

    print('Test results:\n\trel: {:.6f}'.format(total_loss[0]))
    return total_loss[0]


if __name__ == '__main__':
    net = InceptionResNetV2().to(device)
    net, _, _, _, _ = utils.restore_checkpoint(net,
                                               None,
                                               None,
                                               inception=True,
                                               masked=True,
                                               recent=False)
    net.eval()
    print(testNet(net))
Example #15
def main(device=torch.device('cuda:0')):
    # CLI arguments
    parser = arg.ArgumentParser(
        description='We all know what we are doing. Fighting!')
    parser.add_argument("--datasize",
                        "-d",
                        default="small",
                        type=str,
                        help="data size you want to use, small, medium, total")
    # Parsing
    args = parser.parse_args()
    # Data loaders
    datasize = args.datasize
    pathname = "data/nyu.zip"

    # Model
    modelSelection = input(
        'Please input the type of model to be used(res50,dense121,dense169,mob_v2,mob):'
    )

    # Model
    if modelSelection.lower() == 'res50':
        model = Res50()
    elif modelSelection.lower() == 'dense121':
        model = Dense121()
    elif modelSelection.lower() == 'mob_v2':
        model = Mob_v2()
    elif modelSelection.lower() == 'dense169':
        model = Dense169()
    elif modelSelection.lower() == 'mob':
        model = Net()
    elif modelSelection.lower() == 'squeeze':
        model = Squeeze()
    else:
        assert False, 'Wrong type of model selection string!'
    model = model.to(device)

    # Attempts to restore the latest checkpoint if exists
    print("Loading unet...")
    model, start_epoch, stats = utils.restore_checkpoint(
        model, utils.config(modelSelection + ".checkpoint"))

    # Get Test Images
    img_list = glob("examples/" + "*.png")

    # Set model to eval mode
    model.eval()
    model = model.to(device)

    # Begin testing loop
    print("Begin Test Loop ...")

    for idx, img_name in enumerate(img_list):

        img = load_images([img_name])
        img = torch.Tensor(img).float().to(device)
        print("Processing {}, Tensor Shape: {}".format(img_name, img.shape))

        with torch.no_grad():
            preds = model(img).squeeze(0)

        output = colorize(preds.data)
        output = output.transpose((1, 2, 0))
        cv2.imwrite(
            img_name.split(".")[0] + "_" + modelSelection + "_result.png",
            output)

        print("Processing {} done.".format(img_name))
Example #16
config.eval.batch_size = batch_size

random_seed = 0 #@param {"type": "integer"}

sigmas = mutils.get_sigmas(config)
scaler = datasets.get_data_scaler(config)
inverse_scaler = datasets.get_data_inverse_scaler(config)
score_model = mutils.create_model(config)

optimizer = get_optimizer(config, score_model.parameters())
ema = ExponentialMovingAverage(score_model.parameters(),
                               decay=config.model.ema_rate)
state = dict(step=0, optimizer=optimizer,
             model=score_model, ema=ema)

state = restore_checkpoint(ckpt_filename, state, config.device)
ema.copy_to(score_model.parameters())

#@title Visualization code

def image_grid(x):
  size = config.data.image_size
  channels = config.data.num_channels
  img = x.reshape(-1, size, size, channels)
  w = int(np.sqrt(img.shape[0]))
  img = img.reshape((w, w, size, size, channels)).transpose((0, 2, 1, 3, 4)).reshape((w * size, w * size, channels))
  return img

def show_samples(x):
  x = x.permute(0, 2, 3, 1).detach().cpu().numpy()
  img = image_grid(x)
  plt.axis('off')          # assumes matplotlib.pyplot imported as plt
  plt.imshow(img)
  plt.show()
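Examples #16, #19 and #21 follow a third convention: restore_checkpoint(ckpt_path, state, device) takes a state dict holding the optimizer, model, EMA and step, restores each component in place, and returns the dict. A sketch consistent with that calling convention (the checkpoint keys are assumptions based on how state is built):

import torch

def restore_checkpoint(ckpt_path, state, device):
    # Hypothetical sketch: the saved file is assumed to mirror `state`.
    loaded = torch.load(ckpt_path, map_location=device)
    state['optimizer'].load_state_dict(loaded['optimizer'])
    state['model'].load_state_dict(loaded['model'])
    state['ema'].load_state_dict(loaded['ema'])
    state['step'] = loaded['step']
    return state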
Example #17
def run_train(dataloader_cfg, model_cfg, scheduler_cfg, optimizer_cfg,
              loss_cfg, validation_cfg, checkpoint_frequency,
              restore_checkpoint, max_epochs, _run):

    # Lets cuDNN benchmark conv implementations and choose the fastest.
    # Only good if sizes stay the same within the main loop!
    torch.backends.cudnn.benchmark = True
    exit_handler = ExitHandler()

    device = _run.config['device']
    device_id = _run.config['device_id']

    # during training just one dataloader
    dataloader = dataloader_builder.build(dataloader_cfg)[0]

    epoch = 0
    if restore_checkpoint is not None:
        model_cfg, optimizer_cfg, epoch = utils.restore_checkpoint(
            restore_checkpoint, model_cfg, optimizer_cfg)

    def overwrite(to_overwrite, dic):
        to_overwrite.update(dic)
        return to_overwrite

    # some models depend on dataset, for example num_joints
    model_cfg = overwrite(dataloader.dataset.info, model_cfg)
    model = model_builder.build(model_cfg)

    loss_cfg['model'] = model
    loss = loss_builder.build(loss_cfg)
    loss = loss.to(device)

    parameters = list(model.parameters()) + list(loss.parameters())
    optimizer = optimizer_builder.build(optimizer_cfg, parameters)

    lr_scheduler = scheduler_builder.build(scheduler_cfg, optimizer, epoch)

    if validation_cfg is None:
        validation_dataloaders = None
    else:
        validation_dataloaders = dataloader_builder.build(validation_cfg)
        keep = False

    file_logger = log.get_file_logger()
    logger = log.get_logger()

    model = torch.nn.DataParallel(model, device_ids=device_id)
    model.cuda()

    model = model.train()
    trained_models = []

    exit_handler.register(file_logger.save_checkpoint, model, optimizer,
                          "atexit", model_cfg)

    start_training_time = time.time()
    end = time.time()
    while epoch < max_epochs:
        epoch += 1
        lr_scheduler.step()
        logger.info("Starting Epoch %d/%d", epoch, max_epochs)
        len_batch = len(dataloader)
        acc_time = 0
        for batch_id, data in enumerate(dataloader):
            optimizer.zero_grad()
            endpoints = model(data, model.module.endpoints)
            logger.debug("datasets %s", list(data['split_info'].keys()))

            data.update(endpoints)
            # theoretically, losses could also be calculated in a distributed way.
            losses = loss(endpoints, data)
            loss_mean = torch.mean(losses)
            loss_mean.backward()
            optimizer.step()

            acc_time += time.time() - end
            end = time.time()

            report_after_batch(_run=_run,
                               logger=logger,
                               batch_id=batch_id,
                               batch_len=len_batch,
                               acc_time=acc_time,
                               loss_mean=loss_mean,
                               max_mem=torch.cuda.max_memory_allocated())

        if epoch % checkpoint_frequency == 0:
            path = file_logger.save_checkpoint(model, optimizer, epoch,
                                               model_cfg)
            trained_models.append(path)

        report_after_epoch(_run=_run, epoch=epoch, max_epoch=max_epochs)

        if validation_dataloaders is not None and \
                epoch % checkpoint_frequency == 0:
            model.eval()

            # Lets cuDNN benchmark conv implementations and choose the fastest.
            # Only good if sizes stay the same within the main loop!
            # not the case for segmentation
            torch.backends.cudnn.benchmark = False
            score = evaluate(validation_dataloaders, model, epoch, keep=keep)
            logger.info(score)
            log_score(score, _run, prefix="val_", step=epoch)
            torch.backends.cudnn.benchmark = True
            model.train()

    report_after_training(_run=_run,
                          max_epoch=max_epochs,
                          total_time=time.time() - start_training_time)
    path = file_logger.save_checkpoint(model, optimizer, epoch, model_cfg)
    if path:
        trained_models.append(path)
    file_logger.close()
    # TODO get best performing val model
    evaluate_last = _run.config['training'].get('evaluate_last', 1)
    if len(trained_models) < evaluate_last:
        logger.info("Only saved %d models (evaluate_last=%d)",
                    len(trained_models), evaluate_last)
    return trained_models[-evaluate_last:]
Example #18
def main(device, tr_loader, va_loader, te_loader, modelSelection):
    """Train CNN and show training plots."""
    # Model
    if modelSelection.lower() == 'res50':
        model = Res50()
    elif modelSelection.lower() == 'dense121':
        model = Dense121()
    elif modelSelection.lower() == 'dense161':
        model = Dense161()
    elif modelSelection.lower() == 'mobv2':
        model = Mob_v2()
    elif modelSelection.lower() == 'dense169':
        model = Dense169()
    elif modelSelection.lower() == 'mob':
        model = Net()
    elif modelSelection.lower() == 'squeeze':
        model = Squeeze()
    else:
        assert False, 'Wrong type of model selection string!'
    model = model.to(device)

    # TODO: define loss function, and optimizer
    learning_rate = utils.config(modelSelection + ".learning_rate")
    criterion = DepthLoss(0.1).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    number_of_epoches = 10
    #

    # Attempts to restore the latest checkpoint if exists
    print("Loading unet...")
    model, start_epoch, stats = utils.restore_checkpoint(
        model, utils.config(modelSelection + ".checkpoint"))

    running_va_loss = [] if 'va_loss' not in stats else stats['va_loss']
    running_va_acc = [] if 'va_err' not in stats else stats['va_err']
    running_tr_loss = [] if 'tr_loss' not in stats else stats['tr_loss']
    running_tr_acc = [] if 'tr_err' not in stats else stats['tr_err']
    tr_acc, tr_loss = utils.evaluate_model(model, tr_loader, device)
    acc, loss = utils.evaluate_model(model, va_loader, device)
    running_va_acc.append(acc)
    running_va_loss.append(loss)
    running_tr_acc.append(tr_acc)
    running_tr_loss.append(tr_loss)
    stats = {
        'va_err': running_va_acc,
        'va_loss': running_va_loss,
        'tr_err': running_tr_acc,
        'tr_loss': running_tr_loss,
    }
    # Loop over the entire dataset multiple times
    # for epoch in range(start_epoch, config('cnn.num_epochs')):
    epoch = start_epoch
    # while curr_patience < patience:
    while epoch < number_of_epoches:
        # Train model
        utils.train_epoch(device, tr_loader, model, criterion, optimizer)
        # Save checkpoint
        utils.save_checkpoint(model, epoch + 1,
                              utils.config(modelSelection + ".checkpoint"),
                              stats)
        # Evaluate model
        tr_acc, tr_loss = utils.evaluate_model(model, tr_loader, device)
        va_acc, va_loss = utils.evaluate_model(model, va_loader, device)
        running_va_acc.append(va_acc)
        running_va_loss.append(va_loss)
        running_tr_acc.append(tr_acc)
        running_tr_loss.append(tr_loss)
        epoch += 1
    print("Finished Training")
    utils.make_plot(running_tr_loss, running_tr_acc, running_va_loss,
                    running_va_acc)
Example #19
def evaluate(config, workdir, eval_folder="eval"):
    """Evaluate trained models.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints.
    eval_folder: The subfolder for storing evaluation results. Default to
      "eval".
  """
    # Create directory to eval_folder
    eval_dir = os.path.join(workdir, eval_folder)
    tf.io.gfile.makedirs(eval_dir)

    # Build data pipeline
    train_ds, eval_ds, _ = datasets.get_dataset(
        config,
        uniform_dequantization=config.data.uniform_dequantization,
        evaluation=True)

    # Create data normalizer and its inverse
    scaler = datasets.get_data_scaler(config)
    inverse_scaler = datasets.get_data_inverse_scaler(config)

    # Initialize model
    score_model = mutils.create_model(config)
    optimizer = losses.get_optimizer(config, score_model.parameters())
    ema = ExponentialMovingAverage(score_model.parameters(),
                                   decay=config.model.ema_rate)
    state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)

    checkpoint_dir = os.path.join(workdir, "checkpoints")

    # Setup SDEs
    if config.training.sde.lower() == 'vpsde':
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min,
                            beta_max=config.model.beta_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'subvpsde':
        sde = sde_lib.subVPSDE(beta_min=config.model.beta_min,
                               beta_max=config.model.beta_max,
                               N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'vesde':
        sde = sde_lib.VESDE(sigma_min=config.model.sigma_min,
                            sigma_max=config.model.sigma_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-5
    else:
        raise NotImplementedError(f"SDE {config.training.sde} unknown.")

    # Create the one-step evaluation function when loss computation is enabled
    if config.eval.enable_loss:
        optimize_fn = losses.optimization_manager(config)
        continuous = config.training.continuous
        likelihood_weighting = config.training.likelihood_weighting

        reduce_mean = config.training.reduce_mean
        eval_step = losses.get_step_fn(
            sde,
            train=False,
            optimize_fn=optimize_fn,
            reduce_mean=reduce_mean,
            continuous=continuous,
            likelihood_weighting=likelihood_weighting)

    # Create data loaders for likelihood evaluation. Only evaluate on uniformly dequantized data
    train_ds_bpd, eval_ds_bpd, _ = datasets.get_dataset(
        config, uniform_dequantization=True, evaluation=True)
    if config.eval.bpd_dataset.lower() == 'train':
        ds_bpd = train_ds_bpd
        bpd_num_repeats = 1
    elif config.eval.bpd_dataset.lower() == 'test':
        # Go over the dataset 5 times when computing likelihood on the test dataset
        ds_bpd = eval_ds_bpd
        bpd_num_repeats = 5
    else:
        raise ValueError(
            f"No bpd dataset {config.eval.bpd_dataset} recognized.")

    # Build the likelihood computation function when likelihood is enabled
    if config.eval.enable_bpd:
        likelihood_fn = likelihood.get_likelihood_fn(sde, inverse_scaler)

    # Build the sampling function when sampling is enabled
    if config.eval.enable_sampling:
        sampling_shape = (config.eval.batch_size, config.data.num_channels,
                          config.data.image_size, config.data.image_size)
        sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape,
                                               inverse_scaler, sampling_eps)

    # Use inceptionV3 for images with resolution higher than 256.
    inceptionv3 = config.data.image_size >= 256
    inception_model = evaluation.get_inception_model(inceptionv3=inceptionv3)

    begin_ckpt = config.eval.begin_ckpt
    logging.info("begin checkpoint: %d" % (begin_ckpt, ))
    for ckpt in range(begin_ckpt, config.eval.end_ckpt + 1):
        # Wait if the target checkpoint doesn't exist yet
        waiting_message_printed = False
        ckpt_filename = os.path.join(checkpoint_dir,
                                     "checkpoint_{}.pth".format(ckpt))
        while not tf.io.gfile.exists(ckpt_filename):
            if not waiting_message_printed:
                logging.warning("Waiting for the arrival of checkpoint_%d" %
                                (ckpt, ))
                waiting_message_printed = True
            time.sleep(60)

        # Retry with increasing waits in case the file exists but is not
        # ready for reading yet
        ckpt_path = os.path.join(checkpoint_dir, f'checkpoint_{ckpt}.pth')
        try:
            state = restore_checkpoint(ckpt_path, state, device=config.device)
        except Exception:
            time.sleep(60)
            try:
                state = restore_checkpoint(ckpt_path,
                                           state,
                                           device=config.device)
            except Exception:
                time.sleep(120)
                state = restore_checkpoint(ckpt_path,
                                           state,
                                           device=config.device)
        ema.copy_to(score_model.parameters())
        # Compute the loss function on the full evaluation dataset if loss computation is enabled
        if config.eval.enable_loss:
            all_losses = []
            eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
            for i, batch in enumerate(eval_iter):
                eval_batch = torch.from_numpy(batch['image']._numpy()).to(
                    config.device).float()
                eval_batch = eval_batch.permute(0, 3, 1, 2)
                eval_batch = scaler(eval_batch)
                eval_loss = eval_step(state, eval_batch)
                all_losses.append(eval_loss.item())
                if (i + 1) % 1000 == 0:
                    logging.info("Finished %dth step loss evaluation" %
                                 (i + 1))

            # Save loss values to disk or Google Cloud Storage
            all_losses = np.asarray(all_losses)
            with tf.io.gfile.GFile(
                    os.path.join(eval_dir, f"ckpt_{ckpt}_loss.npz"),
                    "wb") as fout:
                io_buffer = io.BytesIO()
                np.savez_compressed(io_buffer,
                                    all_losses=all_losses,
                                    mean_loss=all_losses.mean())
                fout.write(io_buffer.getvalue())

        # Compute log-likelihoods (bits/dim) if enabled
        if config.eval.enable_bpd:
            bpds = []
            for repeat in range(bpd_num_repeats):
                bpd_iter = iter(ds_bpd)  # pytype: disable=wrong-arg-types
                for batch_id in range(len(ds_bpd)):
                    batch = next(bpd_iter)
                    eval_batch = torch.from_numpy(batch['image']._numpy()).to(
                        config.device).float()
                    eval_batch = eval_batch.permute(0, 3, 1, 2)
                    eval_batch = scaler(eval_batch)
                    bpd = likelihood_fn(score_model, eval_batch)[0]
                    bpd = bpd.detach().cpu().numpy().reshape(-1)
                    bpds.extend(bpd)
                    logging.info(
                        "ckpt: %d, repeat: %d, batch: %d, mean bpd: %6f" %
                        (ckpt, repeat, batch_id, np.mean(np.asarray(bpds))))
                    bpd_round_id = batch_id + len(ds_bpd) * repeat
                    # Save bits/dim to disk or Google Cloud Storage
                    with tf.io.gfile.GFile(
                            os.path.join(
                                eval_dir,
                                f"{config.eval.bpd_dataset}_ckpt_{ckpt}_bpd_{bpd_round_id}.npz"
                            ), "wb") as fout:
                        io_buffer = io.BytesIO()
                        np.savez_compressed(io_buffer, bpd)
                        fout.write(io_buffer.getvalue())

        # Generate samples and compute IS/FID/KID when enabled
        if config.eval.enable_sampling:
            num_sampling_rounds = config.eval.num_samples // config.eval.batch_size + 1
            for r in range(num_sampling_rounds):
                logging.info("sampling -- ckpt: %d, round: %d" % (ckpt, r))

                # Directory to save samples. Different for each host to avoid writing conflicts
                this_sample_dir = os.path.join(eval_dir, f"ckpt_{ckpt}")
                tf.io.gfile.makedirs(this_sample_dir)
                samples, n = sampling_fn(score_model)
                samples = np.clip(
                    samples.permute(0, 2, 3, 1).cpu().numpy() * 255., 0,
                    255).astype(np.uint8)
                samples = samples.reshape(
                    (-1, config.data.image_size, config.data.image_size,
                     config.data.num_channels))
                # Write samples to disk or Google Cloud Storage
                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, f"samples_{r}.npz"),
                        "wb") as fout:
                    io_buffer = io.BytesIO()
                    np.savez_compressed(io_buffer, samples=samples)
                    fout.write(io_buffer.getvalue())

                # Force garbage collection before calling TensorFlow code for Inception network
                gc.collect()
                latents = evaluation.run_inception_distributed(
                    samples, inception_model, inceptionv3=inceptionv3)
                # Force garbage collection again before returning to JAX code
                gc.collect()
                # Save latent represents of the Inception network to disk or Google Cloud Storage
                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, f"statistics_{r}.npz"),
                        "wb") as fout:
                    io_buffer = io.BytesIO()
                    np.savez_compressed(io_buffer,
                                        pool_3=latents["pool_3"],
                                        logits=latents["logits"])
                    fout.write(io_buffer.getvalue())

            # Compute inception scores, FIDs and KIDs.
            # Load all statistics that have been previously computed and saved for each host
            all_logits = []
            all_pools = []
            this_sample_dir = os.path.join(eval_dir, f"ckpt_{ckpt}")
            stats = tf.io.gfile.glob(
                os.path.join(this_sample_dir, "statistics_*.npz"))
            for stat_file in stats:
                with tf.io.gfile.GFile(stat_file, "rb") as fin:
                    stat = np.load(fin)
                    if not inceptionv3:
                        all_logits.append(stat["logits"])
                    all_pools.append(stat["pool_3"])

            if not inceptionv3:
                all_logits = np.concatenate(all_logits,
                                            axis=0)[:config.eval.num_samples]
            all_pools = np.concatenate(all_pools,
                                       axis=0)[:config.eval.num_samples]

            # Load pre-computed dataset statistics.
            data_stats = evaluation.load_dataset_stats(config)
            data_pools = data_stats["pool_3"]

            # Compute FID/KID/IS on all samples together.
            if not inceptionv3:
                inception_score = tfgan.eval.classifier_score_from_logits(
                    all_logits)
            else:
                inception_score = -1

            fid = tfgan.eval.frechet_classifier_distance_from_activations(
                data_pools, all_pools)
            # Hack to get tfgan KID work for eager execution.
            tf_data_pools = tf.convert_to_tensor(data_pools)
            tf_all_pools = tf.convert_to_tensor(all_pools)
            kid = tfgan.eval.kernel_classifier_distance_from_activations(
                tf_data_pools, tf_all_pools).numpy()
            del tf_data_pools, tf_all_pools

            logging.info(
                "ckpt-%d --- inception_score: %.6e, FID: %.6e, KID: %.6e" %
                (ckpt, inception_score, fid, kid))

            with tf.io.gfile.GFile(
                    os.path.join(eval_dir, f"report_{ckpt}.npz"), "wb") as f:
                io_buffer = io.BytesIO()
                np.savez_compressed(io_buffer,
                                    IS=inception_score,
                                    fid=fid,
                                    kid=kid)
                f.write(io_buffer.getvalue())
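Before sampling or evaluating, Examples #16, #19 and #21 copy the EMA shadow weights into the live model with ema.copy_to(...), and Example #21 wraps sampling in store/restore so the training weights survive. The idiom can be tried standalone with the torch_ema package, which exposes the same store/copy_to/restore API (the Linear model is a toy stand-in):

import torch
import torch.nn as nn
from torch_ema import ExponentialMovingAverage  # assumed stand-in package

model = nn.Linear(4, 2)                      # toy stand-in for score_model
ema = ExponentialMovingAverage(model.parameters(), decay=0.999)

ema.store(model.parameters())    # stash the live training weights
ema.copy_to(model.parameters())  # run inference with averaged weights
with torch.no_grad():
    out = model(torch.randn(1, 4))
ema.restore(model.parameters())  # put the training weights back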
Example #20
def main(device=torch.device('cuda:0')):
    # CLI arguments
    parser = arg.ArgumentParser(
        description='We all know what we are doing. Fighting!')
    parser.add_argument("--datasize",
                        "-d",
                        default="small",
                        type=str,
                        help="data size you want to use, small, medium, total")
    # Parsing
    args = parser.parse_args()
    # Data loaders
    datasize = args.datasize
    pathname = "data/nyu.zip"
    tr_loader, va_loader, te_loader = getTrainingValidationTestingData(
        datasize, pathname, batch_size=config("unet.batch_size"))

    # Model
    model = Net()

    # TODO: define loss function, and optimizer
    learning_rate = utils.config("unet.learning_rate")
    criterion = DepthLoss(0.1)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    number_of_epoches = 10
    #

    # print("Number of float-valued parameters:", util.count_parameters(model))

    # Attempts to restore the latest checkpoint if exists
    print("Loading unet...")
    model, start_epoch, stats = utils.restore_checkpoint(
        model, utils.config("unet.checkpoint"))

    # axes = utils.make_training_plot()

    # Evaluate the randomly initialized model
    # evaluate_epoch(
    #     axes, tr_loader, va_loader, te_loader, model, criterion, start_epoch, stats
    # )
    # loss = criterion()

    # initial val loss for early stopping
    # prev_val_loss = stats[0][1]

    running_va_loss = []
    running_va_acc = []
    running_tr_loss = []
    running_tr_acc = []
    # TODO: define patience for early stopping
    # patience = 1
    # curr_patience = 0
    #
    tr_acc, tr_loss = utils.evaluate_model(model, tr_loader, device)
    acc, loss = utils.evaluate_model(model, va_loader, device)
    running_va_acc.append(acc)
    running_va_loss.append(loss)
    running_tr_acc.append(tr_acc)
    running_tr_loss.append(tr_loss)

    # Loop over the entire dataset multiple times
    # for epoch in range(start_epoch, config('cnn.num_epochs')):
    epoch = start_epoch
    # while curr_patience < patience:
    while epoch < number_of_epoches:
        # Train model
        utils.train_epoch(tr_loader, model, criterion, optimizer)
        tr_acc, tr_loss = utils.evaluate_model(model, tr_loader, device)
        va_acc, va_loss = utils.evaluate_model(model, va_loader, device)
        running_va_acc.append(va_acc)
        running_va_loss.append(va_loss)
        running_tr_acc.append(tr_acc)
        running_tr_loss.append(tr_loss)
        # Evaluate model
        # evaluate_epoch(
        #     axes, tr_loader, va_loader, te_loader, model, criterion, epoch + 1, stats
        # )

        # Save model parameters
        utils.save_checkpoint(model, epoch + 1,
                              utils.config("unet.checkpoint"), stats)

        # update early stopping parameters
        """
        curr_patience, prev_val_loss = early_stopping(
            stats, curr_patience, prev_val_loss
        )
        """

        epoch += 1
    print("Finished Training")
    # Save figure and keep plot open
    # utils.save_training_plot()
    # utils.hold_training_plot()
    utils.make_plot(running_tr_loss, running_tr_acc, running_va_loss,
                    running_va_acc)
Example #21
def train(config, workdir):
    """Runs the training pipeline.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """

    # Create directories for experimental logs
    sample_dir = os.path.join(workdir, "samples")
    tf.io.gfile.makedirs(sample_dir)

    tb_dir = os.path.join(workdir, "tensorboard")
    tf.io.gfile.makedirs(tb_dir)
    writer = tensorboard.SummaryWriter(tb_dir)

    # Initialize model.
    score_model = mutils.create_model(config)
    ema = ExponentialMovingAverage(score_model.parameters(),
                                   decay=config.model.ema_rate)
    optimizer = losses.get_optimizer(config, score_model.parameters())
    state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)

    # Create checkpoints directory
    checkpoint_dir = os.path.join(workdir, "checkpoints")
    # Intermediate checkpoints to resume training after pre-emption in cloud environments
    checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta",
                                       "checkpoint.pth")
    tf.io.gfile.makedirs(checkpoint_dir)
    tf.io.gfile.makedirs(os.path.dirname(checkpoint_meta_dir))
    # Resume training when intermediate checkpoints are detected
    state = restore_checkpoint(checkpoint_meta_dir, state, config.device)
    initial_step = int(state['step'])

    # Build data iterators
    train_ds, eval_ds, _ = datasets.get_dataset(
        config, uniform_dequantization=config.data.uniform_dequantization)
    train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
    eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
    # Create data normalizer and its inverse
    scaler = datasets.get_data_scaler(config)
    inverse_scaler = datasets.get_data_inverse_scaler(config)

    # Setup SDEs
    if config.training.sde.lower() == 'vpsde':
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min,
                            beta_max=config.model.beta_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'subvpsde':
        sde = sde_lib.subVPSDE(beta_min=config.model.beta_min,
                               beta_max=config.model.beta_max,
                               N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'vesde':
        sde = sde_lib.VESDE(sigma_min=config.model.sigma_min,
                            sigma_max=config.model.sigma_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-5
    else:
        raise NotImplementedError(f"SDE {config.training.sde} unknown.")

    # Build one-step training and evaluation functions
    optimize_fn = losses.optimization_manager(config)
    continuous = config.training.continuous
    reduce_mean = config.training.reduce_mean
    likelihood_weighting = config.training.likelihood_weighting
    train_step_fn = losses.get_step_fn(
        sde,
        train=True,
        optimize_fn=optimize_fn,
        reduce_mean=reduce_mean,
        continuous=continuous,
        likelihood_weighting=likelihood_weighting)
    eval_step_fn = losses.get_step_fn(
        sde,
        train=False,
        optimize_fn=optimize_fn,
        reduce_mean=reduce_mean,
        continuous=continuous,
        likelihood_weighting=likelihood_weighting)

    # Building sampling functions
    if config.training.snapshot_sampling:
        sampling_shape = (config.training.batch_size, config.data.num_channels,
                          config.data.image_size, config.data.image_size)
        sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape,
                                               inverse_scaler, sampling_eps)

    num_train_steps = config.training.n_iters

    # In case there are multiple hosts (e.g., TPU pods), only log to host 0
    logging.info("Starting training loop at step %d." % (initial_step, ))

    for step in range(initial_step, num_train_steps + 1):
        # Convert data to JAX arrays and normalize them. Use ._numpy() to avoid copy.
        batch = torch.from_numpy(next(train_iter)['image']._numpy()).to(
            config.device).float()
        batch = batch.permute(0, 3, 1, 2)
        batch = scaler(batch)
        # Execute one training step
        loss = train_step_fn(state, batch)
        if step % config.training.log_freq == 0:
            logging.info("step: %d, training_loss: %.5e" % (step, loss.item()))
            writer.add_scalar("training_loss", loss, step)

        # Save a temporary checkpoint to resume training after pre-emption periodically
        if step != 0 and step % config.training.snapshot_freq_for_preemption == 0:
            save_checkpoint(checkpoint_meta_dir, state)

        # Report the loss on an evaluation dataset periodically
        if step % config.training.eval_freq == 0:
            eval_batch = torch.from_numpy(
                next(eval_iter)['image']._numpy()).to(config.device).float()
            eval_batch = eval_batch.permute(0, 3, 1, 2)
            eval_batch = scaler(eval_batch)
            eval_loss = eval_step_fn(state, eval_batch)
            logging.info("step: %d, eval_loss: %.5e" %
                         (step, eval_loss.item()))
            writer.add_scalar("eval_loss", eval_loss.item(), step)

        # Save a checkpoint periodically and generate samples if needed
        if (step != 0 and step % config.training.snapshot_freq == 0) or step == num_train_steps:
            # Save the checkpoint.
            save_step = step // config.training.snapshot_freq
            save_checkpoint(
                os.path.join(checkpoint_dir, f'checkpoint_{save_step}.pth'),
                state)

            # Generate and save samples
            if config.training.snapshot_sampling:
                ema.store(score_model.parameters())
                ema.copy_to(score_model.parameters())
                sample, n = sampling_fn(score_model)
                ema.restore(score_model.parameters())
                this_sample_dir = os.path.join(sample_dir,
                                               "iter_{}".format(step))
                tf.io.gfile.makedirs(this_sample_dir)
                nrow = int(np.sqrt(sample.shape[0]))
                image_grid = make_grid(sample, nrow, padding=2)
                sample = np.clip(
                    sample.permute(0, 2, 3, 1).cpu().numpy() * 255, 0,
                    255).astype(np.uint8)
                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, "sample.np"),
                        "wb") as fout:
                    np.save(fout, sample)

                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, "sample.png"),
                        "wb") as fout:
                    save_image(image_grid, fout)