Example 1
def simple_example():
    import numpy as np
    import utils

    utils.fix_random_seeds()

    X = np.array([
        [4.,  4.,  2.,  0.],
        [4., 61.,  8., 18.],
        [2.,  8., 10.,  0.],
        [0., 18.,  0.,  5.]])

    mod = TorchGloVe(embed_dim=2, max_iter=1000)

    print(mod)

    G = mod.fit(X)

    print("\nLearned vectors:")
    print(G)

    print("We expect the dot product of learned vectors "
          "to be proportional to the log co-occurrence probs. "
          "Let's see how close we came:")

    corr = mod.score(X)

    print("Pearson's R: {} ".format(corr))

    return corr
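For reference, a minimal self-contained sketch of the check this example describes, under the assumption that G holds the learned row vectors and that "score" means Pearson correlation between their dot products and the log co-occurrence probabilities over nonzero cells (the library's own score method may differ):

import numpy as np
from scipy.stats import pearsonr

def glove_fit_check(X, G):
    # Hypothetical helper, not part of the library.
    probs = X / X.sum(axis=1, keepdims=True)   # row-normalized co-occurrence probs
    mask = X > 0                               # only cells with observed counts
    preds = np.asarray(G) @ np.asarray(G).T    # dot products of learned vectors
    return pearsonr(preds[mask], np.log(probs[mask]))[0]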
Example 2
def simple_example(group_size=100, vec_dim=2):
    import utils
    from sklearn.model_selection import train_test_split

    utils.fix_random_seeds()

    color_seqs, word_seqs, vocab = create_example_dataset(
        group_size=group_size, vec_dim=vec_dim)

    X_train, X_test, y_train, y_test = train_test_split(
        color_seqs, word_seqs)

    mod = ContextualColorDescriber(vocab)

    print(mod)

    mod.fit(X_train, y_train)

    preds = mod.predict(X_test)

    mod.predict_proba(X_test, y_test)

    correct = 0
    for y, p in zip(y_test, preds):
        correct += int(y == p)

    print("\nExact sequence: {} of {} correct".format(correct, len(y_test)))

    lis_acc = mod.listener_accuracy(X_test, y_test)

    print("\nListener accuracy {}".format(lis_acc))

    return lis_acc
Example 3
def simple_example():
    import numpy as np
    import utils

    utils.fix_random_seeds()

    def randmatrix(m, n, sigma=0.1, mu=0):
        return sigma * np.random.randn(m, n) + mu

    rank = 20
    nrow = 2000
    ncol = 100

    X = randmatrix(nrow, rank).dot(randmatrix(rank, ncol))

    mod = TorchAutoencoder()

    print(mod)

    H = mod.fit(X)

    X_pred = mod.predict(X)

    mse = ((X_pred - X)**2).mean()

    print("\nMSE between actual and reconstructed: {}".format(mse))

    r2 = mod.score(X)

    print("R^2 score: {}".format(r2))

    print("Hidden representations")
    print(H)

    return r2
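Why a small hidden layer can do well here: X is the product of a 2000x20 and a 20x100 matrix, so its rank is at most 20, and the rank-k truncated SVD (Eckart-Young) gives the natural reconstruction baseline. A minimal numpy sketch (hypothetical helper, not part of the module):

import numpy as np

def rank_k_baseline_mse(X, k=20):
    # Best rank-k approximation of X in the least-squares sense.
    U, s, Vt = np.linalg.svd(X, full_matrices=False)
    X_k = (U[:, :k] * s[:k]) @ Vt[:k]
    return ((X - X_k) ** 2).mean()   # essentially zero for the rank-20 X above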
Example 4
def simple_example():
    utils.fix_random_seeds()

    vocab = ['a', 'b', '$UNK']

    # No b before an a
    train = [
        [list('ab'), 'good'], [list('aab'), 'good'], [list('abb'), 'good'],
        [list('aabb'), 'good'], [list('ba'), 'bad'], [list('baa'), 'bad'],
        [list('bba'), 'bad'], [list('bbaa'), 'bad'], [list('aba'), 'bad']]

    test = [[list('baaa'), 'bad'], [list('abaa'), 'bad'],
            [list('bbaa'), 'bad'], [list('aaab'), 'good'],
            [list('aaabb'), 'good']]

    X_train, y_train = zip(*train)
    X_test, y_test = zip(*test)

    mod = TorchRNNClassifier(vocab)

    print(mod)

    mod.fit(X_train, y_train)

    preds = mod.predict(X_test)

    print("\nPredictions:")

    for ex, pred, gold in zip(X_test, preds, y_test):
        score = "correct" if pred == gold else "incorrect"
        print("{0:>6} - predicted: {1:>4}; actual: {2:>4} - {3}".format(
            "".join(ex), pred, gold, score))

    return mod.score(X_test, y_test)
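The "# No b before an a" rule behind the toy labels, spelled out as a plain function (a hypothetical helper, just to make the labeling explicit):

def gold_label(seq):
    # 'good' iff no 'b' occurs anywhere before an 'a'.
    seen_b = False
    for ch in seq:
        if ch == 'b':
            seen_b = True
        elif ch == 'a' and seen_b:
            return 'bad'
    return 'good'

assert gold_label(list('aabb')) == 'good' and gold_label(list('aba')) == 'bad'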
Example 5
def simple_example():
    """Assess on the digits dataset."""
    from sklearn.datasets import load_digits
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import classification_report, accuracy_score
    import utils

    utils.fix_random_seeds()

    digits = load_digits()
    X = digits.data
    y = digits.target
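    # Note: load_digits() provides 1797 flattened 8x8 grayscale digit images
    # (64 features each) spanning the 10 classes 0-9.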

    X_train, X_test, y_train, y_test = train_test_split(X,
                                                        y,
                                                        test_size=0.33,
                                                        random_state=42)

    mod = TorchDeepNeuralClassifier(num_layers=2)

    print(mod)

    mod.fit(X_train, y_train)
    preds = mod.predict(X_test)

    print("\nClassification report:")

    print(classification_report(y_test, preds))

    return accuracy_score(y_test, preds)
Example 6
def test_fix_random_seeds_system(set_value):
    params = dict(seed=42,
                  set_system=set_value,
                  set_torch=False,
                  set_torch_cudnn=False)
    utils.fix_random_seeds(**params)
    x = np.random.random()
    utils.fix_random_seeds(**params)
    y = np.random.random()
    assert (x == y) == set_value
Example 7
def test_fix_random_seeds_pytorch(set_value):
    import torch
    params = dict(seed=42,
                  set_system=False,
                  set_torch=set_value,
                  set_torch_cudnn=set_value)
    utils.fix_random_seeds(**params)
    x = torch.rand(1)
    utils.fix_random_seeds(**params)
    y = torch.rand(1)
    assert (x == y) == set_value
Example 8
def test_fix_random_seeds_tensorflow(set_value):
    import tensorflow as tf
    params = dict(seed=42,
                  set_system=False,
                  set_tensorflow=set_value,
                  set_torch=True,
                  set_torch_cudnn=True)
    utils.fix_random_seeds(**params)
    x = tf.random.uniform([1]).numpy()
    utils.fix_random_seeds(**params)
    y = tf.random.uniform([1]).numpy()
    assert (x == y) == set_value
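The three tests above only exercise the keyword interface of utils.fix_random_seeds. A minimal sketch of a helper with that interface (an assumption for illustration; the actual implementation and defaults may differ):

import random
import numpy as np

def fix_random_seeds(seed=42, set_system=True, set_torch=True,
                     set_tensorflow=False, set_torch_cudnn=True):
    # Seed the Python and NumPy ("system") generators.
    if set_system:
        random.seed(seed)
        np.random.seed(seed)
    # Seed PyTorch on CPU and all CUDA devices.
    if set_torch:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Make cuDNN deterministic.
    if set_torch_cudnn:
        import torch
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    # Seed TensorFlow.
    if set_tensorflow:
        import tensorflow as tf
        tf.random.set_seed(seed)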
Example 9
def simple_example():
    from nltk.tree import Tree
    from sklearn.metrics import accuracy_score
    import utils

    utils.fix_random_seeds()

    train = [
        "(odd 1)", "(even 2)", "(even (odd 1) (neutral (neutral +) (odd 1)))",
        "(odd (odd 1) (neutral (neutral +) (even 2)))",
        "(odd (even 2) (neutral (neutral +) (odd 1)))",
        "(even (even 2) (neutral (neutral +) (even 2)))",
        "(even (odd 1) (neutral (neutral +) (odd (odd 1) (neutral (neutral +) (even 2)))))"
    ]

    test = [
        "(odd (odd 1))", "(even (even 2))",
        "(odd (odd 1) (neutral (neutral +) (even (odd 1) (neutral (neutral +) (odd 1)))))",
        "(even (even 2) (neutral (neutral +) (even (even 2) (neutral (neutral +) (even 2)))))",
        "(odd (even 2) (neutral (neutral +) (odd (even 2) (neutral (neutral +) (odd 1)))))",
        "(even (odd 1) (neutral (neutral +) (odd (even 2) (neutral (neutral +) (odd 1)))))",
        "(odd (even 2) (neutral (neutral +) (odd (odd 1) (neutral (neutral +) (even 2)))))"
    ]

    vocab = ["1", "+", "2"]

    X_train = [Tree.fromstring(x) for x in train]
    y_train = [t.label() for t in X_train]

    X_test = [Tree.fromstring(x) for x in test]
    y_test = [t.label() for t in X_test]

    mod = TreeNN(vocab)

    print(mod)

    mod.fit(X_train, y_train)

    print("\nTest predictions:")

    preds = mod.predict(X_test)

    correct = 0
    for tree, label, pred in zip(X_test, y_test, preds):
        correct += int(pred == label)
        print("{}\n\tPredicted: {}\n\tActual: {}".format(tree, pred, label))
    print("{}/{} correct".format(correct, len(X_test)))

    return accuracy_score(y_test, preds)
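For readers unfamiliar with the bracketed strings above: nltk's Tree.fromstring parses them, and the root label doubles as the gold class. A quick illustration:

from nltk.tree import Tree

t = Tree.fromstring("(odd (odd 1) (neutral (neutral +) (even 2)))")
print(t.label())   # 'odd'  -- the root label used as the class
print(t.leaves())  # ['1', '+', '2']  -- the leaves, drawn from vocab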
Example 10
def main():
    args = configuration.parse_args()
    config = configuration.load_config(args.config_path)
    utils.setup_logging()
    logging.info(f'args: {args}')
    logging.info(f'config: {config}')

    utils.fix_random_seeds(config['hyperparameters']['seed'])

    a_train_dataset, a_test_dataset, a_test_length = create_image_action_dataset(
        config, 'domain_a')
    b_train_dataset, b_test_dataset, b_test_length = create_image_action_dataset(
        config, 'domain_b')

    trainer = create_models_and_trainer(config)

    output_dir, (samples_dir, summaries_dir,
                 checkpoints_dir) = utils.create_output_dirs(
                     args.output_dir_base, 'unit', args.tag,
                     ['samples', 'summaries', 'checkpoints'])
    configuration.dump_config(config, os.path.join(output_dir, 'config.yaml'))

    checkpoint = reload_checkpoint(trainer, checkpoints_dir,
                                   config['restore_path'])
    summary_writer = tf.summary.create_file_writer(summaries_dir)

    with summary_writer.as_default():
        datasets = [(a_train_dataset, a_test_dataset),
                    (b_train_dataset, b_test_dataset)]
        test_iterations = max(a_test_length, b_test_length)
        main_loop(trainer, datasets, test_iterations, config, checkpoint,
                  samples_dir)

    if args.summarize:
        trainer.model.encoder_a.model.summary()
        trainer.model.encoder_b.model.summary()
        trainer.model.encoder_shared.model.summary()
        trainer.model.decoder_shared.model.summary()
        trainer.model.decoder_b.model.summary()
        trainer.model.decoder_a.model.summary()
        trainer.model.downstreamer.model.summary()
        trainer.controller.model.summary()
        trainer.model.dis_a.model.summary()
        trainer.model.dis_b.model.summary()
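A hypothetical sketch of a directory helper consistent with the utils.create_output_dirs call site above (the repo's actual implementation may differ):

import os

def create_output_dirs(base_dir, model_name, tag, subdir_names):
    # Build <base>/<model_name>/<tag>/ plus the requested subdirectories and
    # return the root together with the subdirectory paths in call order.
    output_dir = os.path.join(base_dir, model_name, tag)
    subdirs = tuple(os.path.join(output_dir, name) for name in subdir_names)
    for d in (output_dir,) + subdirs:
        os.makedirs(d, exist_ok=True)
    return output_dir, subdirs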
Example 11
def test_fix_random_seeds_tensorflow(set_value):
    utils.fix_random_seeds(seed=42, set_tensorflow=set_value)
    x = tf.random.uniform([1]).numpy()
    utils.fix_random_seeds(seed=42, set_tensorflow=set_value)
    y = tf.random.uniform([1]).numpy()
    assert (x == y) == set_value
Example 12
def test_fix_random_seeds_pytorch(set_value):
    utils.fix_random_seeds(seed=42, set_torch=set_value)
    x = torch.rand(1)
    utils.fix_random_seeds(seed=42, set_torch=set_value)
    y = torch.rand(1)
    assert (x == y) == set_value
Example 13
def test_fix_random_seeds_system(set_value):
    utils.fix_random_seeds(seed=42, set_system=set_value)
    x = np.random.random()
    utils.fix_random_seeds(seed=42, set_system=set_value)
    y = np.random.random()
    assert (x == y) == set_value
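These tests (and Examples 6-8) take a set_value argument, which in a pytest suite would typically be supplied by parametrization; a sketch of how that could be wired up (an assumption, since the decorators are not shown in the excerpts):

import pytest

@pytest.mark.parametrize("set_value", [True, False])
def test_fix_random_seeds_system(set_value):
    ...  # body as in Example 13 above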
Example 14
            break
    # Write train loss per epoch
    write_to_json_file(os.path.join(path,
        f"{file_name_head}_train_loss_per_epoch"), train_loss)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--data_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The serialized input ppdb pairs.")
    args = parser.parse_args()

    fix_random_seeds(seed=42)
    featurizer = "avg_pooling_featurizer"
    dataset = PPDBSerializedDataset(args.data_path, featurizer)
    model = MLP1Classifier(2, input_dim=768)

    logger.info('Input Data: {} | Featurizer: {}'.format(args.data_path, featurizer))

    serialized_train(dataset,
                     model,
                     encoder='BertModel',
                     featurizer=featurizer,
                     path="./model_checkpoint",
                     epochs=100,
                     lr=0.01,
                     batch_size=1024)
Example 15
def train_dino(args):
    utils.init_distributed_mode(args)
    utils.fix_random_seeds(args.seed)
    print("git:\n  {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v))
                    for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    # ============ preparing data ... ============
    transform = DataAugmentationDINO(
        args.global_crops_scale,
        args.local_crops_scale,
        args.local_crops_number,
    )
    #dataset = datasets.ImageFolder(args.data_path, transform=transform)
    from sen12ms import get_transform
    dataset = AllSen12MSDataset(args.data_path,
                                "train",
                                transform=transform,
                                tansform_coord=None,
                                classes=None,
                                seasons=None,
                                split_by_region=True,
                                download=False)

    sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    print(f"Data loaded: there are {len(dataset)} images.")

    # ============ building student and teacher networks ... ============
    # if the network is a vision transformer (i.e. deit_tiny, deit_small, vit_base)
    if args.arch in vits.__dict__.keys():
        student = vits.__dict__[args.arch](
            patch_size=args.patch_size,
            drop_path_rate=0.1,  # stochastic depth
        )
        teacher = vits.__dict__[args.arch](patch_size=args.patch_size)
        embed_dim = student.embed_dim

        student = utils.replace_input_layer(student, inchannels=13)
        teacher = utils.replace_input_layer(teacher, inchannels=13)

    # otherwise, we check if the architecture is in torchvision models
    elif args.arch in torchvision_models.__dict__.keys():
        student = torchvision_models.__dict__[args.arch]()
        teacher = torchvision_models.__dict__[args.arch]()
        embed_dim = student.fc.weight.shape[1]
    else:
        print(f"Unknow architecture: {args.arch}")

    # multi-crop wrapper handles forward with inputs of different resolutions
    student = utils.MultiCropWrapper(
        student,
        DINOHead(
            embed_dim,
            args.out_dim,
            use_bn=args.use_bn_in_head,
            norm_last_layer=args.norm_last_layer,
        ))
    teacher = utils.MultiCropWrapper(
        teacher,
        DINOHead(embed_dim, args.out_dim, args.use_bn_in_head),
    )
    # move networks to gpu
    student, teacher = student.cuda(), teacher.cuda()
    # synchronize batch norms (if any)
    if utils.has_batchnorms(student):
        student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
        teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)

        # we need DDP wrapper to have synchro batch norms working...
        teacher = nn.parallel.DistributedDataParallel(teacher,
                                                      device_ids=[args.gpu])
        teacher_without_ddp = teacher.module
    else:
        # teacher_without_ddp and teacher are the same thing
        teacher_without_ddp = teacher
    student = nn.parallel.DistributedDataParallel(student,
                                                  device_ids=[args.gpu])
    # teacher and student start with the same weights
    teacher_without_ddp.load_state_dict(student.module.state_dict())
    # there is no backpropagation through the teacher, so no need for gradients
    for p in teacher.parameters():
        p.requires_grad = False
    print(f"Student and Teacher are built: they are both {args.arch} network.")

    # ============ preparing loss ... ============
    dino_loss = DINOLoss(
        args.out_dim,
        # total number of crops = 2 global crops + local_crops_number
        args.local_crops_number + 2,
        args.warmup_teacher_temp,
        args.teacher_temp,
        args.warmup_teacher_temp_epochs,
        args.epochs,
    ).cuda()

    # ============ preparing optimizer ... ============
    params_groups = utils.get_params_groups(student)
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(params_groups)  # to use with ViTs
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(params_groups, lr=0,
                                    momentum=0.9)  # lr is set by scheduler
    elif args.optimizer == "lars":
        optimizer = utils.LARS(
            params_groups)  # to use with convnet and large batches
    # for mixed precision training
    fp16_scaler = None
    if args.use_fp16:
        fp16_scaler = torch.cuda.amp.GradScaler()

    # ============ init schedulers ... ============
    lr_schedule = utils.cosine_scheduler(
        args.lr * (args.batch_size_per_gpu * utils.get_world_size()) /
        256.,  # linear scaling rule
        args.min_lr,
        args.epochs,
        len(data_loader),
        warmup_epochs=args.warmup_epochs,
    )
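    # Worked example of the linear scaling rule above, with hypothetical values
    # (not necessarily the script's defaults): args.lr = 5e-4, 64 images per
    # GPU, 8 GPUs -> peak lr = 5e-4 * (64 * 8) / 256 = 1e-3, which the schedule
    # warms up to over args.warmup_epochs and then decays toward args.min_lr.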
    wd_schedule = utils.cosine_scheduler(
        args.weight_decay,
        args.weight_decay_end,
        args.epochs,
        len(data_loader),
    )
    # momentum parameter is increased to 1. during training with a cosine schedule
    momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1,
                                               args.epochs, len(data_loader))
    print(f"Loss, optimizer and schedulers ready.")

    # ============ optionally resume training ... ============
    to_restore = {"epoch": 0}
    utils.restart_from_checkpoint(
        os.path.join(args.output_dir, "checkpoint.pth"),
        run_variables=to_restore,
        student=student,
        teacher=teacher,
        optimizer=optimizer,
        fp16_scaler=fp16_scaler,
        dino_loss=dino_loss,
    )
    start_epoch = to_restore["epoch"]

    start_time = time.time()
    print("Starting DINO training !")
    for epoch in range(start_epoch, args.epochs):
        data_loader.sampler.set_epoch(epoch)

        # ============ training one epoch of DINO ... ============
        train_stats = train_one_epoch(student, teacher, teacher_without_ddp,
                                      dino_loss, data_loader, optimizer,
                                      lr_schedule, wd_schedule,
                                      momentum_schedule, epoch, fp16_scaler,
                                      args)

        # ============ writing logs ... ============
        save_dict = {
            'student': student.state_dict(),
            'teacher': teacher.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch + 1,
            'args': args,
            'dino_loss': dino_loss.state_dict(),
        }
        if fp16_scaler is not None:
            save_dict['fp16_scaler'] = fp16_scaler.state_dict()
        utils.save_on_master(save_dict,
                             os.path.join(args.output_dir, 'checkpoint.pth'))
        if args.saveckp_freq and epoch % args.saveckp_freq == 0:
            utils.save_on_master(
                save_dict,
                os.path.join(args.output_dir, f'checkpoint{epoch:04}.pth'))
        log_stats = {
            **{f'train_{k}': v
               for k, v in train_stats.items()}, 'epoch': epoch
        }
        if utils.is_main_process():
            with (Path(args.output_dir) / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Example 16
            total_grad += grad * (self.lr) / (2 * self.num_agents *
                                              self.weights_std**2)
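            # With the hyperparameters used below (lr=0.01, num_agents=40,
            # weights_std=0.02), and assuming self.lr matches that lr, this
            # per-perturbation scale factor is 0.01 / (2 * 40 * 0.02**2) = 0.3125.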

        self.grads.append(total_grad)
        self.grads.update_orthogonal()
        centroid_parameters += total_grad
        # Update the centroid
        self.centroid.init_from_parameters(centroid_parameters)
        print("Gradient norm: {}".format(torch.norm(grad, p=2)))

        return report_rew, perturbs_timesteps


seed = 0
env_name = "Hopper-v2"
fix_random_seeds(seed)
writer = SummaryWriter()

env = gym.make(env_name)
policy = MujocoPolicy(len(env.observation_space.high),
                      len(env.action_space.high))
agent = ESAgent(policy)
population = GuidedESPopulation(num_agents=40,
                                num_trials=5,
                                lr=0.01,
                                initial_agent=agent,
                                agent_class=ESAgent,
                                env_name=env_name,
                                weights_std=0.02,
                                seed=seed,
                                num_parallel=3,
Example 17
import numpy as np
import pandas as pd
import pytest
import tempfile
import torch.nn as nn
import utils

from test_torch_model_base import PARAMS_WITH_TEST_VALUES as BASE_PARAMS
from torch_glove import TorchGloVe, simple_example
from np_glove import GloVe

__author__ = "Christopher Potts"
__version__ = "CS224u, Stanford, Spring 2021"


utils.fix_random_seeds()


PARAMS_WITH_TEST_VALUES = [
    ["embed_dim", 20],
    ["alpha", 0.65],
    ["xmax", 75]]


PARAMS_WITH_TEST_VALUES += BASE_PARAMS


@pytest.fixture
def count_matrix():
    return np.array([
        [  4.,   4.,   2.,   0.],
        [  4.,  61.,   8.,  18.],
        [  2.,   8.,  10.,   0.],
        [  0.,  18.,   0.,   5.]])
Example 18
def main():

    # parse arguments
    global args
    parser = parse_arguments()
    args = parser.parse_args()

    # exp setup: logger, distributed mode and seeds
    init_distributed_mode(args)
    init_signal_handler()
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")
    if args.rank == 0:
        writer = SummaryWriter(args.dump_path)
    else:
        writer = None

    # build data
    train_dataset = AVideoDataset(
        ds_name=args.ds_name,
        root_dir=args.root_dir,
        mode='train',
        path_to_data_dir=args.data_path,
        num_frames=args.num_frames,
        target_fps=args.target_fps,
        sample_rate=args.sample_rate,
        num_train_clips=args.num_train_clips,
        train_crop_size=args.train_crop_size,
        test_crop_size=args.test_crop_size,
        num_data_samples=args.num_data_samples,
        colorjitter=args.colorjitter,
        use_grayscale=args.use_grayscale,
        use_gaussian=args.use_gaussian,
        temp_jitter=True,
        decode_audio=True,
        aug_audio=None,
        num_sec=args.num_sec_aud,
        aud_sample_rate=args.aud_sample_rate,
        aud_spec_type=args.aud_spec_type,
        use_volume_jittering=args.use_volume_jittering,
        use_temporal_jittering=args.use_audio_temp_jittering,
        z_normalize=args.z_normalize,
        dual_data=args.dual_data
    )
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True
    )
    logger.info("Loaded data with {} videos.".format(len(train_dataset)))

    # Load model
    model = load_model(
        vid_base_arch=args.vid_base_arch,
        aud_base_arch=args.aud_base_arch,
        use_mlp=args.use_mlp,
        num_classes=args.mlp_dim,
        pretrained=False,
        norm_feat=False,
        use_max_pool=False,
        headcount=args.headcount,
    )

    # synchronize batch norm layers
    if args.sync_bn == "pytorch":
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    elif args.sync_bn == "apex":
        process_group = None
        if args.world_size // 8 > 0:
            process_group = apex.parallel.create_syncbn_process_group(args.world_size // 8)
        model = apex.parallel.convert_syncbn_model(model, process_group=process_group)

    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    # build optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=args.wd,
    )
    if args.use_warmup_scheduler:
        lr_scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=args.world_size,
            total_epoch=args.warmup_epochs,
            after_scheduler=None
        )
    else:
        lr_scheduler = None

    logger.info("Building optimizer done.")

    # init mixed precision
    if args.use_fp16:
        model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O1")
        logger.info("Initializing mixed precision done.")

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )

    # SK-Init
    N_dl = len(train_loader)
    N = len(train_loader.dataset)
    N_distr = N_dl * train_loader.batch_size
    selflabels = torch.zeros((N, args.headcount), dtype=torch.long, device='cuda')
    global sk_schedule
    sk_schedule = (args.epochs * N_dl * (np.linspace(0, 1, args.nopts) ** args.schedulepower)[::-1]).tolist()
    # to make sure we don't make it empty
    sk_schedule = [(args.epochs + 2) * N_dl] + sk_schedule
    logger.info(f'remaining SK opts @ epochs {[np.round(1.0 * t / N_dl, 2) for t in sk_schedule]}')
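    # Worked example of the schedule arithmetic above, with hypothetical values
    # (epochs=10, N_dl=100, nopts=5, schedulepower=1.5):
    # np.linspace(0, 1, 5) ** 1.5 -> [0.0, 0.125, 0.354, 0.650, 1.0]; reversed
    # and scaled by epochs * N_dl this gives [1000, 650, 354, 125, 0], a
    # descending list of iteration counts, with the (epochs + 2) * N_dl
    # sentinel prepended so the list never runs empty.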

    # optionally resume from a checkpoint
    to_restore = {"epoch": 0, 'selflabels': selflabels, 'dist':args.dist}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        model=model,
        optimizer=optimizer,
        amp=apex.amp if args.use_fp16 else None,
    )
    start_epoch = to_restore["epoch"]
    selflabels = to_restore["selflabels"]
    args.dist = to_restore["dist"]

    # Set CuDNN benchmark
    cudnn.benchmark = True

    # Restart schedule correctly
    if start_epoch != 0:
        include = [(qq / N_dl > start_epoch) for qq in sk_schedule]
        # (total number of sk-opts) - (number of sk-opts outstanding)
        global sk_counter
        sk_counter = len(sk_schedule) - sum(include)
        sk_schedule = (np.array(sk_schedule)[include]).tolist()
        if lr_scheduler:
            [lr_scheduler.step() for _ in range(to_restore['epoch'])]

    if start_epoch == 0:
        train_loader.sampler.set_epoch(999)
        warmup_batchnorm(args, model, train_loader, batches=20, group=group)

    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)
        if writer:
            writer.add_scalar('train/epoch', epoch, epoch)

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # train the network
        scores, selflabels = train(
            train_loader, model, optimizer, epoch, writer, selflabels)
        training_stats.update(scores)

        # Update LR scheduler
        if lr_scheduler:
            lr_scheduler.step()

        # save checkpoints
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "dist": args.dist,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "selflabels": selflabels
            }

            if args.use_fp16:
                save_dict["amp"] = apex.amp.state_dict()
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints, "ckp-" + str(epoch) + ".pth")
                )