# Beispiel #1 (Example #1)
# 0
def main():
    """Run the generator/discriminator training pipeline end to end.

    Steps: clear previous outputs and checkpoints, load and normalize the
    training images returned by ``utils.get_fmnist_data`` (presumably
    Fashion-MNIST), build both models and wrap their ``call`` with
    ``tf.contrib.eager.defun``, train, then print model summaries, plot the
    loss history and create a GIF.
    """
    # Start from a clean slate: wipe prior results and checkpoints.
    for folder in ('./results/', './training_checkpoints/'):
        utils.remove_all_files_inside_folder(folder)

    # Only the training images are used; labels and the test split are discarded.
    (fmnist_train_images, _), (_, _) = utils.get_fmnist_data()
    dataset = utils.normalize_data(fmnist_train_images)

    gen_model = utils.Generator()
    disc_model = utils.Discriminator()
    # defun-wrapping each model's call; the original author reports a
    # ~10 secs/epoch speed-up from this.
    for net in (gen_model, disc_model):
        net.call = tf.contrib.eager.defun(net.call)

    ckpt = utils.setup_checkpoint(gen_model, disc_model)
    fixed_noise = utils.generate_constant_random_vector(
        NOISE_DIM, NUM_EXAMPLES_TO_GENERATE)

    loss_history = utils.train(dataset=dataset,
                               epochs=EPOCHS,
                               noise_dim=NOISE_DIM,
                               generator=gen_model,
                               discriminator=disc_model,
                               checkpoint=ckpt,
                               random_vector=fixed_noise)

    # Reporting.
    for net in (gen_model, disc_model):
        net.summary()
    utils.plot_loss(loss_history)
    utils.create_gif()
# Beispiel #2 (Example #2)
# 0
def test_data():
    """Load the test images and their one-hot ground-truth labels.

    Images are read from TEST_DIR in grayscale at ``utils.TARGET_SIZE`` and
    returned as a single batch (batch size equals the number of images,
    shuffling disabled). Labels are looked up by image id in the
    TRUE_TEST_LABELS CSV (columns ``image_id`` and ``label``) and one-hot
    encoded over 7 classes.

    >>> X_test, y_test = test_data()
    >>> print(X_test.shape, y_test.shape)
    (3036, 128, 128, 1) (3036, 7)

    Returns:
        (X_test, y_test_true): the image batch and an (n_images, 7) float
        one-hot label array whose rows follow the order of ``image_ids``.
    """
    num_classes = 7

    # prepare test data
    arrays, image_ids = utils.load_images(TEST_DIR,
                                          grayscale=True,
                                          target_size=utils.TARGET_SIZE)
    generator = utils.Generator(target_size=utils.TARGET_SIZE,
                                color_mode='grayscale',
                                batch_size=len(image_ids),
                                shuffle=False)

    # One batch covering every image; shuffle=False presumably keeps rows
    # aligned with image_ids (TODO confirm in utils.Generator).
    test_gen = generator.numpy_generator(arrays)
    X_test = test_gen.next()

    # the ground-truth labels, keyed by image id
    df = pd.read_csv(TRUE_TEST_LABELS)
    id_to_label = dict(zip(df.image_id.tolist(), df.label.tolist()))
    label_indices = np.asarray(
        [id_to_label[image_id] for image_id in image_ids], dtype=np.int32)
    # Row i of the identity matrix is the one-hot vector for class i.
    y_test_true = np.eye(num_classes)[label_indices]

    return X_test, y_test_true
# Beispiel #3 (Example #3)
# 0
def main():
    """Receive UDP datagrams on (UDP_IP, UDP_PORT) and print them as frames.

    Loops forever; each datagram is decoded with ``utils.Generator().byteToFrame``.
    The socket is used as a context manager so it is closed when the loop is
    left via an exception (e.g. KeyboardInterrupt).
    """
    with socket.socket(socket.AF_INET,  # Internet
                       socket.SOCK_DGRAM) as sock:  # UDP
        sock.bind((UDP_IP, UDP_PORT))
        frameGen = utils.Generator()

        while True:
            data, addr = sock.recvfrom(1024)  # buffer size is 1024 bytes
            print("received message:", frameGen.byteToFrame(data))
# Beispiel #4 (Example #4)
# 0
def train_data():
    """Return a batch iterator over the training images in TRAIN_DIR.

    Images are loaded in grayscale at ``utils.TARGET_SIZE``,
    ``utils.BATCH_SIZE`` images per batch. Labels presumably come from the
    directory layout used by ``dir_generator`` (TODO confirm in utils).

    >>> train_gen = train_data()
    >>> batch_x, batch_y = train_gen.next()
    >>> print(batch_x.shape, batch_y.shape)
    (64, 128, 128, 1) (64, 7)
    """
    generator = utils.Generator(target_size=utils.TARGET_SIZE,
                                color_mode='grayscale',
                                batch_size=utils.BATCH_SIZE)

    return generator.dir_generator(TRAIN_DIR)
# Beispiel #5 (Example #5)
# 0
def main():
    """Accept one TCP connection and print the size of every chunk received.

    Binds a listening socket to (TCP_IP, TCP_PORT), accepts a single client,
    and reads up to BUFFER_SIZE bytes at a time until ``recv`` returns an
    empty bytes object (the peer closed the connection). Both the listening
    socket and the connection are closed on exit.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((TCP_IP, TCP_PORT))
        s.listen(1)
        conn, addr = s.accept()

        with conn:
            # NOTE(review): frameGen is not used below (frame decoding was
            # disabled); kept in case constructing it has side effects.
            frameGen = utils.Generator()
            while True:
                data = conn.recv(BUFFER_SIZE)
                if not data:
                    break
                print(len(data))


if __name__ == '__main__':
    main()
def main(dataset='cifar10',
         data_path='/tmp/data',
         output_dir='/tmp/fixmatch',
         run_id=None,
         seed=1,
         block_depth=4,
         num_filters=32,
         num_labeled=40,
         sample_mode='label_dist_min1',
         num_epochs=1024,
         batches_per_epoch=1024,
         labeled_batch_size=64,
         unlabeled_batch_size=64 * 7,
         unlabeled_weight=1.,
         lr=0.03,
         momentum=0.9,
         nesterov=True,
         weight_decay=5e-4,
         bn_momentum=1e-3,
         exp_moving_avg_decay=1e-3,
         threshold=0.95,
         labeled_aug='weak',
         unlabeled_aug=('weak', 'strong'),
         dist_alignment=False,
         dist_alignment_batches=128,
         dist_alignment_eps=1e-6,
         checkpoint_interval=1024,
         max_checkpoints=25,
         num_workers=4,
         mixed_precision=True,
         devices=('cuda:0', )):
    """FixMatch training.

    Args:
      dataset: the dataset to use ('cifar10', 'cifar100', 'svhn')
      data_path: dataset root directory
      output_dir: directory to save logs and model checkpoints
      run_id: name for training run (output will be saved under output_dir/run_id;
        defaults to the current timestamp)
      seed: random seed
      block_depth: WideResNet block depth
      num_filters: WideResNet base filter count
      num_labeled: number of labeled examples
      sample_mode: labeled dataset sampling mode ('equal', 'label_dist', 'label_dist_min1', 'multinomial',
        'multinomial_min1')
      num_epochs: number of training epochs
      batches_per_epoch: number of batches per epoch
      labeled_batch_size: number of labeled examples per batch
      unlabeled_batch_size: number of unlabeled examples per batch (total batch size will be
        labeled_batch_size + 2 * unlabeled_batch_size)
      unlabeled_weight: weight of unlabeled loss term
      lr: SGD initial learning rate
      momentum: SGD momentum parameter
      nesterov: whether to use SGD with Nesterov acceleration
      weight_decay: weight decay parameter
      bn_momentum: batch normalization momentum parameter
      exp_moving_avg_decay: model parameter exponential moving average decay
      threshold: confidence threshold
      labeled_aug: data augmentation mode for labeled examples ('none', 'weak', 'strong', 'weak_noflip',
        'strong_noflip'). 'strong' augmentation uses RandAugment. 'noflip' disables horizontal flip augmentation.
      unlabeled_aug: pair of augmentations for unlabeled examples
      dist_alignment: whether to apply distribution alignment heuristic
      dist_alignment_batches: number of batches used to compute moving average of label distribution
      dist_alignment_eps: smoothing parameter for estimating label distribution
      checkpoint_interval: number of batches between checkpoints (None disables periodic
        checkpoints; a final checkpoint is always written)
      max_checkpoints: maximum number of checkpoints to retain
      num_workers: number of workers per data loader
      mixed_precision: whether to use mixed precision training
      devices: list of devices for data parallel training (evaluation runs on devices[0])

    Raises:
      ValueError: if ``dataset`` is not one of the supported names.
    """

    # initial setup
    num_batches = num_epochs * batches_per_epoch

    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Snapshot of all arguments (plus num_batches) for reproducibility. Any new
    # local assigned above this line would also end up in args.pkl.
    args = dict(locals())
    logger.info(pprint.pformat(args))

    run_id = datetime.datetime.now().isoformat() if run_id is None else run_id
    output_dir = os.path.join(output_dir, str(run_id))
    logger.info('output dir = %s', output_dir)
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)
    train_logger, eval_logger = TableLogger(), TableLogger()

    # load datasets
    if dataset == 'cifar10':
        dataset_fn = get_cifar10
    elif dataset == 'cifar100':
        dataset_fn = get_cifar100
    elif dataset == 'svhn':
        dataset_fn = get_svhn
    else:
        raise ValueError('Invalid dataset ' + dataset)
    datasets = dataset_fn(data_path,
                          num_labeled,
                          labeled_aug=labeled_aug,
                          unlabeled_aug=unlabeled_aug,
                          sample_mode=sample_mode,
                          whiten=True)

    model = modules.WideResNet(num_classes=datasets['labeled'].num_classes,
                               bn_momentum=bn_momentum,
                               block_depth=block_depth,
                               channels=num_filters)
    # Constructors (not instances): the trainer binds them to model params.
    optimizer = partial(torch.optim.SGD,
                        lr=lr,
                        momentum=momentum,
                        nesterov=nesterov,
                        weight_decay=weight_decay)
    scheduler = partial(utils.WarmupCosineLrScheduler,
                        warmup_iter=0,
                        max_iter=num_batches)
    evaluator = ModelEvaluator(datasets['test'],
                               labeled_batch_size + unlabeled_batch_size,
                               num_workers)
    param_avg_ctor = partial(modules.EMA, alpha=exp_moving_avg_decay)

    def evaluate(model, avg_model, iter):
        """Evaluate raw and averaged models, log both, return averaged accuracy."""
        results = evaluator.evaluate(model, device=devices[0])
        avg_results = evaluator.evaluate(avg_model, device=devices[0])
        valid_stats = {
            'valid_loss': avg_results.log_loss,
            'valid_accuracy': avg_results.accuracy,
            'valid_loss_noavg': results.log_loss,
            'valid_accuracy_noavg': results.accuracy
        }
        eval_logger.write(iter=iter, **valid_stats)
        eval_logger.step()
        return avg_results.accuracy

    def checkpoint(model,
                   avg_model,
                   optimizer,
                   scheduler,
                   iter,
                   fmt='ckpt-{:08d}.pt'):
        """Save training state, prune the oldest ckpt file, and pickle both logs."""
        path = os.path.join(output_dir, fmt.format(iter))
        torch.save(
            dict(iter=iter,
                 model=model.state_dict(),
                 avg_model=avg_model.state_dict(),
                 optimizer=optimizer.state_dict(),
                 scheduler=scheduler.state_dict()), path)
        # Zero-padded iteration numbers make lexicographic order chronological.
        checkpoint_files = sorted(
            filter(lambda x: re.match(r'^ckpt-[0-9]+\.pt$', x),
                   os.listdir(output_dir)))
        if len(checkpoint_files) > max_checkpoints:
            os.remove(os.path.join(output_dir, checkpoint_files[0]))
        train_logger.to_dataframe().to_pickle(
            os.path.join(output_dir, 'train.log.pkl'))
        eval_logger.to_dataframe().to_pickle(
            os.path.join(output_dir, 'eval.log.pkl'))

    trainer = FixMatch(num_iters=num_batches,
                       num_workers=num_workers,
                       model_optimizer_ctor=optimizer,
                       lr_scheduler_ctor=scheduler,
                       param_avg_ctor=param_avg_ctor,
                       labeled_batch_size=labeled_batch_size,
                       unlabeled_batch_size=unlabeled_batch_size,
                       unlabeled_weight=unlabeled_weight,
                       threshold=threshold,
                       dist_alignment=dist_alignment,
                       dist_alignment_batches=dist_alignment_batches,
                       dist_alignment_eps=dist_alignment_eps,
                       mixed_precision=mixed_precision,
                       devices=devices)

    timer = utils.Timer()
    with tqdm(desc='train', total=num_batches, position=0) as train_pbar:
        train_iter = utils.Generator(
            trainer.train_iter(model, datasets['labeled'],
                               datasets['unlabeled']))
        eval_acc = None

        # training loop
        for i, stats in enumerate(train_iter):
            train_pbar.set_postfix(loss=stats.loss,
                                   eval_acc=eval_acc,
                                   refresh=False)
            train_pbar.update()
            train_logger.write(loss=stats.loss,
                               loss_labeled=stats.loss_labeled,
                               loss_unlabeled=stats.loss_unlabeled,
                               threshold_frac=stats.threshold_frac,
                               time=timer())

            if (checkpoint_interval is not None and i > 0 and (i + 1) % checkpoint_interval == 0) or \
                    (i == num_batches - 1):
                # Evaluate before checkpointing so the eval log pickled by
                # checkpoint() already contains this evaluation; otherwise the
                # final evaluation would never be written to eval.log.pkl.
                eval_acc = evaluate(stats.model, stats.avg_model, i + 1)
                checkpoint(stats.model, stats.avg_model, stats.optimizer,
                           stats.scheduler, i + 1)
                logger.info('eval acc = %.4f | allocated frac = %.4f' %
                            (eval_acc, stats.threshold_frac))

            train_logger.step()
# Beispiel #7 (Example #7)
# 0
def main():
    """Train an RNN language model on ROCStories and periodically sample text.

    Builds a vocabulary from the corpus, splits it into train/test files,
    trains ``lm.LM`` with SGD and NLL loss, and saves the weights to
    ``models/<DATASET>-<epoch>.pkl`` whenever the mean test loss improves.
    Every GENERATE_EVERY epochs three sequences are sampled starting from
    ``_BOS`` and printed with special tokens removed.
    """
    # ======================
    # Hyper-parameters
    # ======================
    CELL = "lstm"  # rnn, gru, lstm
    DATASET = 'movie'
    RATIO = 0.9
    WORD_DROP = 10
    MIN_LEN = 5
    MAX_LEN = 200
    BATCH_SIZE = 32
    SEQUENCE_LEN = 50
    EMBED_SIZE = 128
    HIDDEN_DIM = 256
    NUM_LAYERS = 2
    DROPOUT_RATE = 0.0
    EPOCH = 300
    LEARNING_RATE = 0.01
    MAX_GENERATE_LENGTH = 20
    GENERATE_EVERY = 5
    SEED = 100

    # Print every hyper-parameter above; at this point locals() holds only
    # those constants, so no other local may be assigned before this call.
    all_var = locals()
    print()
    for var in all_var:
        if var != "var_name":
            print("{0:15}   ".format(var), all_var[var])
    print()

    # ======================
    # Data
    # ======================
    data_path = '../../__data/ROCStories.txt'
    train_path = 'train_roc'
    test_path = 'test_roc'
    vocabulary = utils.Vocabulary(data_path,
                                  max_len=MAX_LEN,
                                  min_len=MIN_LEN,
                                  word_drop=WORD_DROP)
    utils.split_corpus(data_path,
                       train_path,
                       test_path,
                       max_len=MAX_LEN,
                       min_len=MIN_LEN,
                       ratio=RATIO,
                       seed=SEED)
    train = utils.Corpus(train_path,
                         vocabulary,
                         max_len=MAX_LEN,
                         min_len=MIN_LEN)
    test = utils.Corpus(test_path,
                        vocabulary,
                        max_len=MAX_LEN,
                        min_len=MIN_LEN)
    train_generator = utils.Generator(train.corpus)
    test_generator = utils.Generator(test.corpus)

    # ======================
    # Build the model
    # ======================
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = lm.LM(cell=CELL,
                  vocab_size=vocabulary.vocab_size,
                  embed_size=EMBED_SIZE,
                  hidden_dim=HIDDEN_DIM,
                  num_layers=NUM_LAYERS,
                  dropout_rate=DROPOUT_RATE)
    model.to(device)
    summary(model, (20, ))
    criterion = nn.NLLLoss()
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)
    print()

    # ======================
    # Training and testing
    # ======================
    best_loss = 1000000
    for epoch in range(EPOCH):
        train_g = train_generator.build_generator(BATCH_SIZE, SEQUENCE_LEN)
        test_g = test_generator.build_generator(BATCH_SIZE, SEQUENCE_LEN)

        train_loss = []
        model.train()
        while True:
            # Only exhaustion ends the epoch; any other error propagates.
            try:
                text = next(train_g)
            except StopIteration:
                break
            optimizer.zero_grad()
            # Next-token prediction: inputs are tokens [0, T-1), targets [1, T).
            y = model(torch.from_numpy(text[:, :-1]).long().to(device))
            loss = criterion(
                y.reshape(-1, vocabulary.vocab_size),
                torch.from_numpy(text[:, 1:]).reshape(-1).long().to(device))
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())

        test_loss = []
        model.eval()
        with torch.no_grad():
            while True:
                try:
                    text = next(test_g)
                except StopIteration:
                    break
                y = model(torch.from_numpy(text[:, :-1]).long().to(device))
                loss = criterion(
                    y.reshape(-1, vocabulary.vocab_size),
                    torch.from_numpy(text[:, 1:]).reshape(-1).long().to(device))
                test_loss.append(loss.item())

        mean_test_loss = np.mean(test_loss)
        print('epoch {:d}   training loss {:.4f}    test loss {:.4f}'.format(
            epoch + 1, np.mean(train_loss), mean_test_loss))

        if mean_test_loss < best_loss:
            best_loss = mean_test_loss
            print('-----------------------------------------------------')
            print('saving parameters')
            os.makedirs('models', exist_ok=True)
            torch.save(model.state_dict(),
                       'models/' + DATASET + '-' + str(epoch) + '.pkl')
            print('-----------------------------------------------------')

        # epoch is 0-based here, hence the +1.
        if (epoch + 1) % GENERATE_EVERY == 0:
            with torch.no_grad():
                # Generate text: three sequences seeded with _BOS.
                x = torch.LongTensor([[vocabulary.w2i['_BOS']]] * 3).to(device)
                for _ in range(MAX_GENERATE_LENGTH):
                    samp = model.sample(x)
                    x = torch.cat([x, samp], dim=1)
                x = x.cpu().numpy()
            special_ids = [
                vocabulary.w2i['_BOS'], vocabulary.w2i['_EOS'],
                vocabulary.w2i['_PAD']
            ]
            print('-----------------------------------------------------')
            for row in x:
                print(' '.join([
                    vocabulary.i2w[_] for _ in list(row) if _ not in special_ids
                ]))
            print('-----------------------------------------------------')
# Beispiel #8 (Example #8)
# 0
def main():
    """Train an RNN language model on ``data/<DATASET>2020.txt`` and sample text.

    Builds a vocabulary, splits the corpus into train/test files, trains
    ``lm.LM`` with Adam and NLL loss (padding ignored), optionally resuming
    from ``models/<DATASET>-<START_EPOCH>.pkl``. Weights are saved to
    ``models/<DATASET>-<epoch>.pkl`` whenever the mean test loss improves,
    and every GENERATE_EVERY epochs three sequences are sampled from _BOS.
    Epochs are numbered from START_EPOCH + 1 to EPOCH (1-based).
    """
    # ======================
    # hyper-parameters
    # ======================
    CELL = "lstm"  # rnn, gru, lstm
    DATASET = 'tweet'  # movie, news, tweet
    RATIO = 0.9
    WORD_DROP = 10
    MIN_LEN = 5
    MAX_LEN = 200
    BATCH_SIZE = 32
    EMBED_SIZE = 350
    HIDDEN_DIM = 512
    NUM_LAYERS = 2
    DROPOUT_RATE = 0.0
    START_EPOCH = 0
    EPOCH = 30
    LEARNING_RATE = 0.001
    MAX_GENERATE_LENGTH = 20
    GENERATE_EVERY = 5
    PRINT_EVERY = 1
    SEED = 100

    # Print every hyper-parameter above; at this point locals() holds only
    # those constants, so no other local may be assigned before this call.
    all_var = locals()
    print()
    for var in all_var:
        if var != "var_name":
            print("{0:15}   ".format(var), all_var[var])
    print()

    # ======================
    # data
    # ======================
    data_path = 'data/' + DATASET + '2020.txt'
    train_path = 'data/train_' + DATASET
    test_path = 'data/test_' + DATASET
    vocabulary = utils.Vocabulary(data_path,
                                  max_len=MAX_LEN,
                                  min_len=MIN_LEN,
                                  word_drop=WORD_DROP)
    utils.split_corpus(data_path,
                       train_path,
                       test_path,
                       max_len=MAX_LEN,
                       min_len=MIN_LEN,
                       ratio=RATIO,
                       seed=SEED)
    train = utils.Corpus(train_path,
                         vocabulary,
                         max_len=MAX_LEN,
                         min_len=MIN_LEN)
    test = utils.Corpus(test_path,
                        vocabulary,
                        max_len=MAX_LEN,
                        min_len=MIN_LEN)
    train_generator = utils.Generator(train.corpus, vocabulary=vocabulary)
    test_generator = utils.Generator(test.corpus, vocabulary=vocabulary)

    # ======================
    # building model
    # ======================
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    model = lm.LM(cell=CELL,
                  vocab_size=vocabulary.vocab_size,
                  embed_size=EMBED_SIZE,
                  hidden_dim=HIDDEN_DIM,
                  num_layers=NUM_LAYERS,
                  dropout_rate=DROPOUT_RATE)
    model.to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print("Total params: {:d}".format(total_params))
    total_trainable_params = sum(p.numel() for p in model.parameters()
                                 if p.requires_grad)
    print("Trainable params: {:d}".format(total_trainable_params))
    # Padding positions do not contribute to the loss.
    criterion = nn.NLLLoss(ignore_index=vocabulary.w2i["_PAD"])
    optimizer = optim.Adam(model.parameters(),
                           lr=LEARNING_RATE,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0,
                           amsgrad=False)
    print()

    # ======================
    # training and testing
    # ======================
    best_loss = 1000000
    step = 0
    if START_EPOCH > 0:
        model.load_state_dict(
            torch.load('models/' + DATASET + '-' + str(START_EPOCH) + '.pkl',
                       map_location=device))
    for epoch in range(START_EPOCH + 1, EPOCH + 1):
        train_g = train_generator.build_generator(BATCH_SIZE)
        test_g = test_generator.build_generator(BATCH_SIZE)
        train_loss = []
        model.train()
        while True:
            # Only exhaustion ends the epoch; any other error propagates.
            try:
                text = next(train_g)
            except StopIteration:
                break
            optimizer.zero_grad()
            # Next-token prediction: inputs are tokens [0, T-1), targets [1, T).
            text_in = text[:, :-1]
            text_target = text[:, 1:]
            y = model(torch.from_numpy(text_in).long().to(device))
            loss = criterion(
                y.reshape(-1, vocabulary.vocab_size),
                torch.from_numpy(text_target).reshape(-1).long().to(device))
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
            step += 1
            torch.cuda.empty_cache()

            if step % PRINT_EVERY == 0:
                print('step {:d} training loss {:.4f}'.format(
                    step, loss.item()))

        test_loss = []
        model.eval()
        with torch.no_grad():
            while True:
                try:
                    text = next(test_g)
                except StopIteration:
                    break
                text_in = text[:, :-1]
                text_target = text[:, 1:]
                y = model(torch.from_numpy(text_in).long().to(device))
                loss = criterion(
                    y.reshape(-1, vocabulary.vocab_size),
                    torch.from_numpy(text_target).reshape(-1).long().to(
                        device))
                test_loss.append(loss.item())
                torch.cuda.empty_cache()

        mean_test_loss = np.mean(test_loss)
        print('epoch {:d}   training loss {:.4f}    test loss {:.4f}'.format(
            epoch, np.mean(train_loss), mean_test_loss))

        if mean_test_loss < best_loss:
            best_loss = mean_test_loss
            print('-----------------------------------------------------')
            print('saving parameters')
            os.makedirs('models', exist_ok=True)
            torch.save(model.state_dict(),
                       'models/' + DATASET + '-' + str(epoch) + '.pkl')
            print('-----------------------------------------------------')

        # epoch is already 1-based here, so no +1 (that would sample one
        # epoch early: 4, 9, 14, ...).
        if epoch % GENERATE_EVERY == 0:
            model.eval()
            with torch.no_grad():
                # generating text: three sequences seeded with _BOS
                x = torch.LongTensor([[vocabulary.w2i['_BOS']]] * 3).to(device)
                for _ in range(MAX_GENERATE_LENGTH):
                    samp = model.sample(x)
                    x = torch.cat([x, samp], dim=1)
                x = x.cpu().numpy()
            special_ids = [
                vocabulary.w2i['_BOS'], vocabulary.w2i['_EOS'],
                vocabulary.w2i['_PAD']
            ]
            print('-----------------------------------------------------')
            for row in x:
                print(' '.join([
                    vocabulary.i2w[_] for _ in list(row) if _ not in special_ids
                ]))
            print('-----------------------------------------------------')
def main(
        num_workers=8,
        num_filters=32,
        dataset='cifar10',
        data_path='/tmp/data',
        output_dir='/tmp/sla',
        run_id=None,
        num_labeled=40,
        seed=1,
        num_epochs=1024,
        batches_per_epoch=1024,
        checkpoint_interval=1024,
        snapshot_interval=None,
        max_checkpoints=25,
        optimizer='sgd',
        lr=0.03,
        momentum=0.9,
        nesterov=True,
        weight_decay=5e-4,
        bn_momentum=1e-3,
        labeled_batch_size=64,
        unlabeled_batch_size=64*7,
        unlabeled_weight=1.,
        exp_moving_avg_decay=1e-3,
        allocation_schedule=((0., 1.), (0., 1.)),
        entropy_reg=100.,
        update_tol=0.01,
        labeled_aug='weak',
        unlabeled_aug=('weak', 'strong'),
        whiten=True,
        sample_mode='label_dist_min1',
        upper_bound_method='empirical',
        upper_bound_kwargs={},
        mixed_precision=True,
        devices=('cuda:0',)):
    """SLA self-training.

    Args:
      num_workers: data loader workers (passed to ModelEvaluator and SLASelfTraining)
      num_filters: WideResNet base filter count
      dataset: the dataset to use ('cifar10', 'cifar100', 'svhn')
      data_path: dataset root directory
      output_dir: directory to save logs and model checkpoints
      run_id: name for training run (output saved under output_dir/run_id;
        defaults to the current timestamp)
      num_labeled: number of labeled examples
      seed: random seed for python, numpy and torch
      num_epochs: number of training epochs
      batches_per_epoch: number of batches per epoch
      checkpoint_interval: batches between checkpoints (None disables periodic
        checkpoints; a final checkpoint is always written)
      snapshot_interval: batches between 'snapshot-XXXXXXXX.pt' files; these are
        not matched by the checkpoint pruning pattern, so they are never deleted
      max_checkpoints: maximum number of 'ckpt-*.pt' checkpoints to retain
      optimizer: only 'sgd' is supported
      lr, momentum, nesterov, weight_decay: torch.optim.SGD parameters
      bn_momentum: batch normalization momentum parameter
      labeled_batch_size: labeled examples per batch
      unlabeled_batch_size: unlabeled examples per batch
      unlabeled_weight: weight of the unlabeled loss term (forwarded to SLASelfTraining)
      exp_moving_avg_decay: alpha of the modules.EMA parameter average
      allocation_schedule: pair unpacked into utils.PiecewiseLinear (presumably
        knots and values — TODO confirm in utils)
      entropy_reg, update_tol: forwarded to SLASelfTraining
      labeled_aug, unlabeled_aug, whiten, sample_mode: forwarded to the dataset loader
      upper_bound_method, upper_bound_kwargs: forwarded to SLASelfTraining; the
        kwargs mapping is copied so the shared default dict is never mutated
      mixed_precision: whether to use mixed precision training
      devices: training devices; evaluation runs on devices[0]

    Raises:
      ValueError: for an unsupported ``optimizer`` or an unknown ``dataset``.
    """
    # Only SGD is implemented below; fail fast rather than silently ignoring
    # the option. (No new locals may be created before args is captured.)
    if optimizer != 'sgd':
        raise ValueError('Unsupported optimizer ' + str(optimizer))
    # Defensive copy: the default {} is shared across calls.
    if upper_bound_kwargs is not None:
        upper_bound_kwargs = dict(upper_bound_kwargs)

    # initial setup
    num_batches = num_epochs * batches_per_epoch

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Snapshot of all arguments (plus num_batches) for reproducibility.
    args = dict(locals())
    logger.info(pprint.pformat(args))

    run_id = datetime.datetime.now().isoformat() if run_id is None else run_id
    output_dir = os.path.join(output_dir, str(run_id))
    logger.info('output dir = %s', output_dir)
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)
    train_logger, eval_logger = TableLogger(), TableLogger()

    # load datasets
    if dataset == 'cifar10':
        dataset_fn = get_cifar10
    elif dataset == 'cifar100':
        dataset_fn = get_cifar100
    elif dataset == 'svhn':
        dataset_fn = get_svhn
    else:
        raise ValueError('Invalid dataset ' + dataset)
    datasets = dataset_fn(
        data_path, num_labeled, labeled_aug=labeled_aug, unlabeled_aug=unlabeled_aug,
        sample_mode=sample_mode, whiten=whiten)

    model = modules.WideResNet(
        num_classes=datasets['labeled'].num_classes, bn_momentum=bn_momentum, channels=num_filters)
    # Constructors (not instances): the trainer binds them to model params.
    optimizer = partial(torch.optim.SGD, lr=lr, momentum=momentum, nesterov=nesterov, weight_decay=weight_decay)
    scheduler = partial(utils.WarmupCosineLrScheduler, warmup_iter=0, max_iter=num_batches)
    evaluator = ModelEvaluator(datasets['test'], labeled_batch_size + unlabeled_batch_size, num_workers)

    def evaluate(model, avg_model, iter):
        """Evaluate raw and averaged models, log both, return averaged accuracy."""
        results = evaluator.evaluate(model, device=devices[0])
        avg_results = evaluator.evaluate(avg_model, device=devices[0])
        valid_stats = {
            'valid_loss': avg_results.log_loss,
            'valid_accuracy': avg_results.accuracy,
            'valid_loss_noavg': results.log_loss,
            'valid_accuracy_noavg': results.accuracy
        }
        eval_logger.write(
            iter=iter,
            **valid_stats)
        eval_logger.step()
        return avg_results.accuracy

    def checkpoint(model, avg_model, optimizer, scheduler, iter, fmt='ckpt-{:08d}.pt'):
        """Save training state, prune the oldest ckpt file, and pickle both logs."""
        path = os.path.join(output_dir, fmt.format(iter))
        torch.save(dict(
            iter=iter,
            model=model.state_dict(),
            avg_model=avg_model.state_dict(),
            optimizer=optimizer.state_dict(),
            scheduler=scheduler.state_dict()), path)
        # Zero-padded iteration numbers make lexicographic order chronological;
        # snapshot-*.pt files do not match and are never pruned.
        checkpoint_files = sorted(filter(lambda x: re.match(r'^ckpt-[0-9]+\.pt$', x), os.listdir(output_dir)))
        if len(checkpoint_files) > max_checkpoints:
            os.remove(os.path.join(output_dir, checkpoint_files[0]))
        train_logger.to_dataframe().to_pickle(os.path.join(output_dir, 'train.log.pkl'))
        eval_logger.to_dataframe().to_pickle(os.path.join(output_dir, 'eval.log.pkl'))

    trainer = SLASelfTraining(
        num_epochs=num_epochs,
        batches_per_epoch=batches_per_epoch,
        num_workers=num_workers,
        model_optimizer_ctor=optimizer,
        lr_scheduler_ctor=scheduler,
        param_avg_ctor=partial(modules.EMA, alpha=exp_moving_avg_decay),
        labeled_batch_size=labeled_batch_size,
        unlabeled_batch_size=unlabeled_batch_size,
        unlabeled_weight=unlabeled_weight,
        allocation_schedule=utils.PiecewiseLinear(*allocation_schedule),
        entropy_reg=entropy_reg,
        update_tol=update_tol,
        upper_bound_method=upper_bound_method,
        upper_bound_kwargs=upper_bound_kwargs,
        mixed_precision=mixed_precision,
        devices=devices)

    timer = utils.Timer()
    with tqdm(desc='train', total=num_batches, position=0) as train_pbar:
        train_iter = utils.Generator(
            trainer.train_iter(model, datasets['labeled'].num_classes, datasets['labeled'], datasets['unlabeled']))
        # Primed coroutines: .send(x) returns the updated moving average.
        smoothed_loss = utils.ema(0.3, avg_only=True)
        smoothed_loss.send(None)
        smoothed_acc = utils.ema(1., avg_only=False)
        smoothed_acc.send(None)
        eval_stats = None, None

        # training loop
        # NOTE(review): `i` counts every item yielded by train_iter, including
        # non-Stats items skipped below; if such items are yielded, `i` is not
        # the batch index used for checkpoint timing — confirm in SLASelfTraining.
        for i, stats in enumerate(train_iter):
            if isinstance(stats, trainer.__class__.Stats):
                train_pbar.set_postfix(
                    loss=smoothed_loss.send(stats.loss), eval_acc=eval_stats[0], eval_v=eval_stats[1], refresh=False)
                train_pbar.update()
                train_logger.write(
                    loss=stats.loss, loss_labeled=stats.loss_labeled, loss_unlabeled=stats.loss_unlabeled,
                    mean_imputed_labels=stats.label_vars.data.mean(0).cpu().numpy(),
                    scaling_vars=stats.scaling_vars.data.cpu().numpy(),
                    allocation_param=stats.allocation_param,
                    assigned_frac=stats.label_vars.data.sum(-1).mean(),
                    assignment_err=stats.assgn_err, assignment_iters=stats.assgn_iters, time=timer())

                if (checkpoint_interval is not None
                        and i > 0 and (i + 1) % checkpoint_interval == 0) or (i == num_batches - 1):
                    # Evaluate first so the pickled eval log includes this result.
                    eval_acc = evaluate(stats.model, stats.avg_model, i+1)
                    eval_stats = smoothed_acc.send(eval_acc)
                    checkpoint(stats.model, stats.avg_model, stats.optimizer, stats.scheduler, i+1)
                    logger.info('eval acc = %.4f | allocated frac = %.4f | allocation param = %.4f' %
                                (eval_acc, stats.label_vars.mean(0).sum().cpu().item(), stats.allocation_param))
                    logger.info('assignment err = %.4e | assignment iters = %d' % (stats.assgn_err, stats.assgn_iters))
                    logger.info('batch assignments = {}'.format(stats.label_vars.mean(0).cpu().numpy()))
                    logger.info('scaling vars = {}'.format(stats.scaling_vars.cpu().numpy()))

                # take snapshots that are guaranteed to be preserved
                if snapshot_interval is not None and i > 0 and (i + 1) % snapshot_interval == 0:
                    checkpoint(stats.model, stats.avg_model, stats.optimizer,
                               stats.scheduler, i + 1, 'snapshot-{:08d}.pt')

                train_logger.step()