Example #1
def find_files_to_process():
    files_from_crawler = list(flattened(recursive_listdir(DOWNLOAD_DIR)))

    files_to_process = []
    files_to_ignore = []
    for path in files_from_crawler:
        try:
            import_date = find_date(path)
            size = os.path.getsize(path)
            files_to_process.append((path, import_date, size))
        except ValueError:
            files_to_ignore.append(path)

    def _import_date(entry):
        (_path, import_date, _size) = entry
        return import_date

    def _size(entry):
        (_path, _import_date, size) = entry
        return size

    bytes_accumulator = Accumulator()
    files_to_process.sort(key=_import_date)
    files_to_process = [(f, bytes_accumulator(_size(f)))
                        for f in files_to_process]
    bytes_to_process = bytes_accumulator.getvalue()

    return (bytes_to_process, files_to_process, files_to_ignore)
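The Accumulator class itself is not shown in any of these examples. From the call sites here, calling the instance with a byte count advances a running total and getvalue() reads the final sum; the FAADSPlusFormat example further down feeds field lengths through the same pattern to compute field offsets, which implies the call returns the total before the addition. A minimal sketch consistent with that usage (an assumption, not the original class):

class Accumulator:
    """Running total. Each call adds `value` and returns the total
    before the addition (assumed from the offset computation in the
    FAADSPlusFormat example below)."""

    def __init__(self, start=0):
        self._total = start

    def __call__(self, value):
        previous = self._total
        self._total += value
        return previous

    def getvalue(self):
        return self._total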
Example #2
def test(epoch):
    test_accum = Accumulator()
    for batch in test_tasks:
        h, meta_batch = run_regression(batch, train=False)

        x_train, y_train = batch["train"][0].cuda(), batch["train"][1].cuda()
        x_test, y_test = batch["test"][0].cuda(), batch["test"][1].cuda()
        with torch.no_grad():
            preds_train = h(x_train)
            preds_test = h(x_test)

            l_train = mse_criterion(preds_train.squeeze(), y_train.squeeze())
            l_test = mse_criterion(preds_test.squeeze(), y_test.squeeze())
            gap = l_test.mean(-1) - l_train.mean(-1)

            model_preds = model(meta_batch)
            loss = mse_criterion(model_preds.squeeze(), gap.squeeze()).mean()
            mae = mae_criterion(model_preds.squeeze(), gap.squeeze()).mean()

        test_accum.add_dict({
            "l_test": [l_test.mean(-1).detach().cpu()],
            "l_train": [l_train.mean(-1).detach().cpu()],
            "mae": [mae.item()],
            "loss": [loss.item()],
            "gap": [gap.squeeze().detach().cpu()],
            "pred": [model_preds.squeeze().detach().cpu()],
        })

    all_gaps = torch.cat(test_accum["gap"])
    all_preds = torch.cat(test_accum["pred"])
    R = np.corrcoef(all_gaps, all_preds)[0, 1]
    mean_l_test = torch.cat(test_accum["l_test"]).mean()
    mean_l_train = torch.cat(test_accum["l_train"]).mean()

    writer.add_scalar("test/R", R, epoch)
    writer.add_scalar("test/MAE", test_accum.mean("mae"), epoch)
    writer.add_scalar("test/loss", test_accum.mean("loss"), epoch)
    writer.add_scalar("test/l_test", mean_l_test, epoch)
    writer.add_scalar("test/l_train", mean_l_train, epoch)

    logger.info(f"Test epoch {epoch}")
    logger.info(
        f"mae {test_accum.mean('mae'):.2e} loss {test_accum.mean('loss'):.2e} R {R:.3f} "
        f"l_test {mean_l_test:.2e} l_train {mean_l_train:.2e}")
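In this example the Accumulator is a different beast: add_dict() extends per-key lists across batches, indexing returns the collected list, and mean() averages one key. A sketch matching those call sites (assumed API, not the original):

from collections import defaultdict

class Accumulator:
    """Per-key list collector, inferred from the add_dict /
    __getitem__ / mean usage above; the real class may differ."""

    def __init__(self):
        self._data = defaultdict(list)

    def add_dict(self, d):
        for key, values in d.items():
            self._data[key].extend(values)

    def __getitem__(self, key):
        return self._data[key]

    def mean(self, key):
        values = self._data[key]
        return sum(values) / len(values)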
Example #3
model = Net(
    model_dim=args.model_dim,
    mlp_dim=args.mlp_dim,
    num_classes=args.num_classes,
    word_embedding_dim=args.word_embedding_dim,
    initial_embeddings=initial_embeddings,
)

# Init optimizer.
optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))

print(model)
print("Total Params: {}".format(
    sum(torch.numel(p.data) for p in model.parameters())))

A = Accumulator()

# Train loop.
for step in range(args.max_training_steps):

    start = time.time()

    data, target, lengths = make_batch(next(training_iter),
                                       args.style == "dynamic")

    model.train()
    optimizer.zero_grad()
    y = model(data, lengths)
    # torch.autograd.Variable is deprecated; tensors can be used directly
    loss = F.nll_loss(y, target)
    loss.backward()
    optimizer.step()
Example #4
class Generator:
    def __init__(self, args):
        self.args = args
        self.batch_size = args.batch_size
        self.data_path = args.data_path
        self.num_sample = args.num_sample
        self.max_epoch = args.max_epoch
        self.save_epoch = args.save_epoch
        self.model_path = args.model_path
        self.save_path = args.save_path
        self.model_name = args.model_name
        self.test = args.test
        self.device = torch.device("cuda:0")

        graph_config = load_graph_config(args.graph_data_name, args.nvt,
                                         args.data_path)
        self.model = GeneratorModel(args, graph_config)
        self.model.to(self.device)

        if self.test:
            self.data_name = args.data_name
            self.num_class = args.num_class
            self.load_epoch = args.load_epoch
            self.num_gen_arch = args.num_gen_arch
            load_model(self.model, self.model_path, self.load_epoch)

        else:
            self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4)
            self.scheduler = ReduceLROnPlateau(self.optimizer,
                                               'min',
                                               factor=0.1,
                                               patience=10,
                                               verbose=True)
            self.mtrloader = get_meta_train_loader(self.batch_size,
                                                   self.data_path,
                                                   self.num_sample)
            self.mtrlog = Log(
                self.args,
                open(
                    os.path.join(self.save_path, self.model_name,
                                 'meta_train_generator.log'), 'w'))
            self.mtrlog.print_args()
            self.mtrlogger = Accumulator('loss', 'recon_loss', 'kld')
            self.mvallogger = Accumulator('loss', 'recon_loss', 'kld')

    def meta_train(self):
        sttime = time.time()
        for epoch in range(1, self.max_epoch + 1):
            self.mtrlog.ep_sttime = time.time()
            loss = self.meta_train_epoch(epoch)
            self.scheduler.step(loss)
            self.mtrlog.print(self.mtrlogger, epoch, tag='train')

            self.meta_validation()
            self.mtrlog.print(self.mvallogger, epoch, tag='valid')

            if epoch % self.save_epoch == 0:
                save_model(epoch, self.model, self.model_path)

        self.mtrlog.save_time_log()

    def meta_train_epoch(self, epoch):
        self.model.to(self.device)
        self.model.train()
        train_loss, recon_loss, kld_loss = 0, 0, 0

        self.mtrloader.dataset.set_mode('train')
        for x, g, acc in tqdm(self.mtrloader):
            self.optimizer.zero_grad()
            mu, logvar = self.model.set_encode(x.to(self.device))
            loss, recon, kld = self.model.loss(mu, logvar, g)
            loss.backward()
            self.optimizer.step()

            cnt = len(x)
            self.mtrlogger.accum(
                [loss.item() / cnt,
                 recon.item() / cnt,
                 kld.item() / cnt])
        return self.mtrlogger.get('loss')

    def meta_validation(self):
        self.model.to(self.device)
        self.model.eval()
        train_loss, recon_loss, kld_loss = 0, 0, 0

        self.mtrloader.dataset.set_mode('valid')
        for x, g, acc in tqdm(self.mtrloader):
            with torch.no_grad():
                mu, logvar = self.model.set_encode(x.to(self.device))
                loss, recon, kld = self.model.loss(mu, logvar, g)

            cnt = len(x)
            self.mvallogger.accum(
                [loss.item() / cnt,
                 recon.item() / cnt,
                 kld.item() / cnt])
        return self.mvallogger.get('loss')

    def meta_test(self):
        if self.data_name == 'all':
            for data_name in [
                    'cifar100', 'cifar10', 'mnist', 'svhn', 'aircraft', 'pets'
            ]:
                self.meta_test_per_dataset(data_name)
        else:
            self.meta_test_per_dataset(self.data_name)

    def meta_test_per_dataset(self, data_name):
        meta_test_path = os.path.join(self.save_path, 'meta_test', data_name,
                                      'generated_arch')
        os.makedirs(meta_test_path, exist_ok=True)

        meta_test_loader = get_meta_test_loader(self.data_path, data_name,
                                                self.num_sample,
                                                self.num_class)

        print(f'==> generate architectures for {data_name}')
        runs = 10 if data_name in ['cifar10', 'cifar100'] else 1
        elapsed_time = []
        for run in range(1, runs + 1):
            print(f'==> run {run}/{runs}')
            elapsed_time.append(
                self.generate_architectures(meta_test_loader, data_name,
                                            meta_test_path, run,
                                            self.num_gen_arch))
            print(f'==> done\n')

        time_path = os.path.join(self.save_path, 'meta_test', data_name,
                                 'time.txt')
        with open(time_path, 'w') as f_time:
            msg = f'generator elapsed time {np.mean(elapsed_time):.2f}s'
            print(f'==> save time in {time_path}')
            f_time.write(msg + '\n')
            print(msg)

    def generate_architectures(self, meta_test_loader, data_name,
                               meta_test_path, run, num_gen_arch):
        self.model.eval()
        self.model.to(self.device)

        architecture_string_lst = []
        total_cnt, valid_cnt = 0, 0
        flag = False

        start = time.time()
        with torch.no_grad():
            for x in meta_test_loader:
                mu, logvar = self.model.set_encode(x.to(self.device))
                z = self.model.reparameterize(mu, logvar)
                generated_graph_lst = self.model.graph_decode(z)
                for g in generated_graph_lst:
                    architecture_string = decode_igraph_to_NAS_BENCH_201_string(
                        g)
                    total_cnt += 1
                    if architecture_string is not None:
                        if architecture_string not in architecture_string_lst:
                            valid_cnt += 1
                            architecture_string_lst.append(architecture_string)
                            if valid_cnt == num_gen_arch:
                                flag = True
                                break
                if flag:
                    break
        elapsed = time.time() - start

        spath = os.path.join(meta_test_path, f"run_{run}.txt")
        with open(spath, 'w') as f:
            print(f'==> save generated architectures in {spath}')
            msg = f'elapsed time: {elapsed:6.2f}s '
            print(msg)
            f.write(msg + '\n')
            for i, architecture_string in enumerate(architecture_string_lst):
                f.write(f"{architecture_string}\n")
        return elapsed
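The Accumulator('loss', 'recon_loss', 'kld') used for logging above tracks named metrics: accum() takes one value per metric and get() returns an average. A sketch consistent with those calls (assumed; the real class also cooperates with Log.print):

class Accumulator:
    """Named running averages, inferred from the accum()/get()
    calls in Generator above."""

    def __init__(self, *names):
        self._names = list(names)
        self._sums = [0.0] * len(names)
        self._count = 0

    def accum(self, values):
        for i, value in enumerate(values):
            self._sums[i] += value
        self._count += 1

    def get(self, name):
        return self._sums[self._names.index(name)] / max(self._count, 1)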
Example #5
    save_op, best_save_op = utils.init_savers(args)

    with tf.name_scope("tr_eval"):
        tr_summary = utils.get_summary('ce cr image'.split())
    with tf.name_scope("val_eval"):
        val_summary = utils.get_summary('ce cr fer image'.split())

    with tf.Session() as sess:
        sess.run(init_op)
        summary_writer = tf.summary.FileWriter(args.logdir,
                                               sess.graph,
                                               flush_secs=5.0)

        # ce, accuracy, compression ratio
        accu_list = [Accumulator() for i in range(3)]
        ce, ac, cr = accu_list

        _best_score = np.iinfo(np.int32).max

        epoch_sw, disp_sw, eval_sw = StopWatch(), StopWatch(), StopWatch()

        # For each epoch
        for _epoch in range(1, args.n_epoch + 1):
            epoch_sw.reset()
            disp_sw.reset()

            print('--')
            print('Epoch {} training'.format(_epoch))

            for accu in accu_list:
                accu.reset()  # hypothetical: the excerpt ends here; a per-epoch reset is assumed
Example #6
def train(args, global_model, raw_data_train, raw_data_test):
    start_time = time.time()
    user_list = list(raw_data_train[2].keys())[:100]
    nusers = len(user_list)
    cluster_models = [copy.deepcopy(global_model)]
    del global_model
    cluster_models[0].to(device)
    cluster_assignments = [
        user_list.copy()
    ]  # all users assigned to single cluster_model in beginning

    if args.cfl_wsharing:
        shaccumulator = Accumulator()

    if args.frac == -1:
        m = args.cpr
        if m > nusers:
            raise ValueError(
                f"Clients Per Round: {args.cpr} is greater than number of users: {nusers}"
            )
    else:
        m = max(int(args.frac * nusers), 1)
    print(f"Training {m} users each round")
    print(f"Trying to split after every {args.cfl_split_every} rounds")

    train_loss, train_accuracy = [], []
    for epoch in range(args.epochs):
        # CFL
        if (epoch + 1) % args.cfl_split_every == 0:
            all_losses = []
            new_cluster_models, new_cluster_assignments = [], []
            for cidx, (cluster_model, assignments) in enumerate(
                    tzip(cluster_models,
                         cluster_assignments,
                         desc="Try to split each cluster")):
                # First, train all models in cluster
                local_weights = []
                for user in tqdm(assignments,
                                 desc="Train ALL users in the cluster",
                                 leave=False):
                    local_model = LocalUpdate(args=args,
                                              raw_data=raw_data_train,
                                              user=user)
                    w, loss = local_model.update_weights(
                        copy.deepcopy(cluster_model),
                        local_ep_override=args.cfl_local_epochs)
                    local_weights.append(copy.deepcopy(w))
                    all_losses.append(loss)

                # record shared weights so far
                if args.cfl_wsharing:
                    shaccumulator.add(local_weights)

                weight_updates = subtract_weights(local_weights,
                                                  cluster_model.state_dict(),
                                                  args)
                similarities = pairwise_cossim(weight_updates)

                max_norm = compute_max_update_norm(weight_updates)
                mean_norm = compute_mean_update_norm(weight_updates)

                # wandb.log({"mean_norm / eps1": mean_norm, "max_norm / eps2": max_norm}, commit=False)
                split = (mean_norm < args.cfl_e1 and max_norm > args.cfl_e2
                         and len(assignments) > args.cfl_min_size)
                print(f"CIDX: {cidx}[{len(assignments)}] elem")
                print(
                    f"mean_norm: {(mean_norm):.4f}; max_norm: {(max_norm):.4f}"
                )
                print(f"split? {split}")
                if split:
                    c1, c2 = cluster_clients(similarities)
                    assignments1 = [assignments[i] for i in c1]
                    assignments2 = [assignments[i] for i in c2]
                    new_cluster_assignments += [assignments1, assignments2]
                    print(
                        f"Cluster[{cidx}][{len(assignments)}] -> ({len(assignments1)}, {len(assignments2)})"
                    )

                    local_weights1 = [local_weights[i] for i in c1]
                    local_weights2 = [local_weights[i] for i in c2]

                    cluster_model.load_state_dict(
                        average_weights(local_weights1))
                    new_cluster_models.append(cluster_model)

                    cluster_model2 = copy.deepcopy(cluster_model)
                    cluster_model2.load_state_dict(
                        average_weights(local_weights2))
                    new_cluster_models.append(cluster_model2)

                else:
                    cluster_model.load_state_dict(
                        average_weights(local_weights))
                    new_cluster_models.append(cluster_model)
                    new_cluster_assignments.append(assignments)

            # Write everything
            cluster_models = new_cluster_models
            if args.cfl_wsharing:
                shaccumulator.write(cluster_models)
                shaccumulator.flush()
            cluster_assignments = new_cluster_assignments
            train_loss.append(sum(all_losses) / len(all_losses))

        # Regular FedAvg
        else:
            all_losses = []

            # Do FedAvg for each cluster
            for cluster_model, assignments in tzip(
                    cluster_models,
                    cluster_assignments,
                    desc="Train each cluster through FedAvg"):
                if args.sample_dist == "uniform":
                    sampled_users = random.sample(assignments, m)
                else:
                    xs = np.linspace(-args.sigm_domain, args.sigm_domain,
                                     len(assignments))
                    sigmdist = 1 / (1 + np.exp(-xs))
                    sampled_users = np.random.choice(
                        assignments, m, p=sigmdist / sigmdist.sum())

                local_weights = []
                for user in tqdm(sampled_users,
                                 desc="Training Selected Users",
                                 leave=False):
                    local_model = LocalUpdate(args=args,
                                              raw_data=raw_data_train,
                                              user=user)
                    w, loss = local_model.update_weights(
                        copy.deepcopy(cluster_model))
                    local_weights.append(copy.deepcopy(w))
                    all_losses.append(loss)

                # update global and shared weights
                if args.cfl_wsharing:
                    shaccumulator.add(local_weights)
                new_cluster_weights = average_weights(local_weights)
                cluster_model.load_state_dict(new_cluster_weights)

            if args.cfl_wsharing:
                shaccumulator.write(cluster_models)
                shaccumulator.flush()
            train_loss.append(sum(all_losses) / len(all_losses))

        # Evaluate every user on the held-out data at each epoch,
        # regardless of whether this was a CFL split round or a FedAvg round
        test_acc, test_loss = [], []
        for cluster_model, assignments in zip(cluster_models,
                                              cluster_assignments):
            for user in assignments:
                local_model = LocalUpdate(args=args,
                                          raw_data=raw_data_test,
                                          user=user)
                acc, loss = local_model.inference(model=cluster_model)
                test_acc.append(acc)
                test_loss.append(loss)
        train_accuracy.append(sum(test_acc) / len(test_acc))

        wandb.log({
            "Train Loss": train_loss[-1],
            "Test Accuracy": (100 * train_accuracy[-1]),
            "Clusters": len(cluster_models)
        })
        print(
            f"Train Loss: {train_loss[-1]:.4f}\t Test Accuracy: {(100 * train_accuracy[-1]):.2f}%"
        )

    print(f"Results after {args.epochs} global rounds of training:")
    print("Avg Train Accuracy: {:.2f}%".format(100 * train_accuracy[-1]))
    print(f"Total Run Time: {(time.time() - start_time):0.4f}")
Example #7
if args.nc_weight != 1.0:
    arg_strings.append(f"w{args.nc_weight}")
if args.pool == "pma":
    arg_strings.append(f"heads{args.num_heads}")
args.log_dir = "result/summary/temp/" + "_".join(arg_strings)
args.model_path = f"{args.log_dir}/model.ckpt"
os.makedirs(args.log_dir, exist_ok=True)
writer = SummaryWriter(log_dir=args.log_dir)
set_logger(f"{args.log_dir}/logs.log")
logger.info(f"unknown={unknown}\n Args: {args}")

model = NeuralComplexity1D(args).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
mse_criterion = nn.MSELoss(reduction="none")
mae_criterion = nn.L1Loss()
accum = Accumulator()
global_timestamp = timer()
global_step = 0

test_tasks = get_task(
    saved=True,
    task=args.task,
    batch_size=args.task_batch_size,
    num_steps=args.test_steps,
)
logger.info(f"Dataset loading took {timer() - global_timestamp:.2f} seconds")


class MemoryBank:
    """
    Memory bank class. Stores snapshots of task learners.
    """
Example #8
def main():
    batch_size = args['data.batch_size']
    _, transform_template = init_dataset(args)
    trainset, valset, train_sampler, val_sampler = get_trainval_samplers(args)
    s2d = partial(sample_to_device, device=device)

    # if distilling with extra data
    if args['distil.unsup_size'] > 0:
        distilset = CocoDataset(args['distil.unsup_set'],
                                transform=transform_template(mode='train'))
        distil_sampler = DataLoader(distilset,
                                    batch_size=args['distil.unsup_size'],
                                    shuffle=True,
                                    num_workers=args['data.num_workers'])

    val_loss_fn = get_val_loss_fn(args)
    train_loss_fn = get_train_loss_fn(args)

    # defining the model
    model = get_model(len(trainset.all_classes), args, ensemble=False)
    ensemble = get_model(len(trainset.all_classes), args, ensemble=True)
    ensemble.train()
    distil = Distilation(ensemble, T=args['distil.T'])
    optimizer = get_optimizer(model, args)

    checkpointer = CheckPointer('singles', args, model, optimizer=optimizer)
    if os.path.isfile(checkpointer.last_ckpt) and args['train.resume']:
        start_epoch, best_val_loss, best_val_acc, waiting_for =\
            checkpointer.restore_model(ckpt='last')
    else:
        print('No checkpoint restoration')
        best_val_loss = float('inf')
        best_val_acc = waiting_for = start_epoch = 0

    # defining the summary writer
    writer = SummaryWriter(checkpointer.model_path)

    gamma = args['distil.gamma']
    unsup = args['distil.unsup_size'] > 0
    n_train, n_val = len(train_sampler), len(val_sampler)
    epoch_loss = Accumulator(n_train)
    epoch_distil = Accumulator(n_train)
    epoch_acc = Accumulator(n_train)
    for epoch in range(start_epoch, args['train.epochs']):
        print('\n !!!!!! Starting epoch %d !!!!!!' % epoch)

        if not unsup:
            distil_sampler = range(len(train_sampler))

        model.train()
        for i, (dist_sample,
                sample) in enumerate(zip(distil_sampler, tqdm(train_sampler))):
            optimizer.zero_grad()

            sample = s2d(sample)
            if unsup:
                dist_sample = s2d(dist_sample)
                all_images = torch.cat(
                    [sample['images'], dist_sample['images']], 0)
            else:
                all_images = sample['images']

            logits = model.forward(all_images)
            ce_loss, stats_dict, _ = train_loss_fn(logits[:batch_size],
                                                   sample['labels'])
            distil_loss = distil.get_loss(all_images, logits)
            batch_loss = ((1 - gamma) * ce_loss +
                          gamma * distil.T**2 * distil_loss)

            epoch_distil.append(distil_loss.item())
            epoch_loss.append(stats_dict['loss'])
            epoch_acc.append(stats_dict['acc'])

            batch_loss.backward()
            optimizer.step()

            t = epoch * n_train + i
            if t % 100 == 0:
                writer.add_scalar('loss/train_loss', epoch_loss.mean(last=100),
                                  t)
                writer.add_scalar('accuracy/train_acc',
                                  epoch_acc.mean(last=100), t)
                writer.add_scalar('loss/distil_loss',
                                  epoch_distil.mean(last=100), t)
                writer.add_scalar('learning_rate',
                                  optimizer.param_groups[0]['lr'], t)

        if args['data.dataset'] == 'cub' and epoch % 5 != 0:
            continue
        model.eval()
        val_loss, val_acc = Accumulator(n_val), Accumulator(n_val)
        for j, val_sample in enumerate(tqdm(val_sampler)):
            with torch.no_grad():
                _, stats_dict, _ = val_loss_fn(model, val_sample)
            val_loss.append(stats_dict['loss'])
            val_acc.append(stats_dict['acc'])

        print(
            'train_loss: {0:.4f}, train_acc {1:.2f}%, val_loss: {2:.4f}, val_acc {3:.2f}%'
            .format(epoch_loss.mean(),
                    epoch_acc.mean() * 100, val_loss.mean(),
                    val_acc.mean() * 100))

        # write summaries
        writer.add_scalar('loss/val_loss', val_loss.mean(), epoch)
        writer.add_scalar('accuracy/val_acc', val_acc.mean() * 100, epoch)

        if val_acc.mean() > best_val_acc:
            best_val_loss = val_loss.mean()
            best_train_loss = epoch_loss.mean()
            best_val_acc = val_acc.mean()
            best_train_acc = epoch_acc.mean()
            waiting_for = 0
            is_best = True

            print('Best model so far!')
        else:
            waiting_for += 1
            is_best = False
            if waiting_for >= args['train.patience']:
                mult = args['train.decay_coef']
                print('Decaying lr by the factor of {}'.format(mult))

                # loading the best model so far and optimizing from that point
                checkpointer.restore_model(ckpt='best', model=True)

                for param_group in optimizer.param_groups:
                    param_group['lr'] = param_group['lr'] / mult
                    lr_ = param_group['lr']
                waiting_for = 0
                if args['train.learning_rate'] / lr_ >= mult**2 - 0.1:
                    print('Out of patience')
                    break

        # saving checkpoints
        if epoch % args['train.ckpt_freq'] == 0 or is_best:
            extra = {'distil_name': args['distil.name']}
            checkpointer.save_checkpoint(epoch,
                                         best_val_acc,
                                         best_val_loss,
                                         waiting_for,
                                         is_best,
                                         optimizer=optimizer,
                                         extra=extra)

    writer.close()
    print(
        '\n Done with train_loss: {0:.4f}, train_acc {1:.2f}%, val_loss: {2:.4f}, val_acc {3:.2f}%'
        .format(best_train_loss, best_train_acc * 100, best_val_loss,
                best_val_acc * 100))
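This training script (and the similar loops in the later examples) constructs Accumulator(n) with a capacity and calls append(), mean(), and mean(last=100) for a trailing-window average. A sketch matching that interface (an assumption based on the call sites):

from collections import deque

class Accumulator:
    """Bounded scalar buffer with an optional trailing-window mean,
    inferred from the append()/mean(last=...) usage above."""

    def __init__(self, maxlen):
        self._values = deque(maxlen=maxlen)

    def append(self, value):
        self._values.append(value)

    def mean(self, last=None):
        values = list(self._values)
        if last is not None:
            values = values[-last:]
        return sum(values) / max(len(values), 1)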
Example #9
def main():
    batch_size = args['data.batch_size']
    trainset, valset, train_sampler, val_sampler = get_trainval_samplers(args)
    s2d = partial(sample_to_device, device=device)

    train_loss_fn = get_train_loss_fn(args)
    val_loss_fn = get_val_loss_fn(args)

    # Defining the model and Restoring the last checkpoint
    model = get_model(len(trainset.all_classes), args, ensemble=True)
    optimizer = get_optimizer(model, args)

    checkpointer = CheckPointer('ensembles',
                                args,
                                model=model,
                                optimizer=optimizer)
    if os.path.isfile(checkpointer.last_ckpt) and args['train.resume']:
        start_epoch, best_val_loss, best_val_acc, waiting_for =\
            checkpointer.restore_model(ckpt='last')
    else:
        print('No checkpoint restoration')
        best_val_loss = float('inf')
        best_val_acc = waiting_for = start_epoch = 0

    # defining the summary writer
    writer = SummaryWriter(checkpointer.model_path)

    n_train, n_val = len(train_sampler), len(val_sampler)
    epoch_loss = Accumulator(n_train)
    epoch_acc = Accumulator(n_train)
    for epoch in range(start_epoch, args['train.epochs']):
        print('\n !!!!!! Starting epoch %d !!!!!!' % epoch)

        model.train()
        for i, sample in enumerate(tqdm(train_sampler)):
            optimizer.zero_grad()

            sample = s2d(sample)
            if args['ens.robust_matching']:
                images = sample['images']
                labels = sample['labels'][:batch_size]
                new_shape = [args['ens.num_heads'], batch_size] + list(
                    images.shape[1:])
                images = images.view(new_shape).unbind(dim=0)
            else:
                images, labels = sample['images'], sample['labels']
            logits_list = model.forward(images)
            labels_list = [labels for _ in range(len(logits_list))]
            batch_loss, stats_dict, _ = train_loss_fn(logits_list, labels_list)
            if (args['ens.joint_training']
                    and epoch >= args['ens.rel_active_epoch']):
                sd_loss = relation_loss(torch.stack(logits_list, -1),
                                        labels,
                                        reg_type=args['ens.rel_fn'],
                                        T=args['ens.rel_T'])
                batch_loss += sd_loss * args['ens.rel_coef']
            else:
                sd_loss = 0
            epoch_loss.append(stats_dict['loss'])
            epoch_acc.append(stats_dict['acc'])

            batch_loss.backward()
            optimizer.step()

            t = epoch * n_train + i
            if t % 100 == 0:
                writer.add_scalar('loss/train_loss', epoch_loss.mean(last=100),
                                  t)
                writer.add_scalar('accuracy/train_acc',
                                  epoch_acc.mean(last=100), t)
                writer.add_scalar('learning_rate',
                                  optimizer.param_groups[0]['lr'], t)
                if args['ens.joint_training']:
                    writer.add_scalar('loss/softmax_diverse', sd_loss, t)

        if args['data.dataset'] == 'cub' and epoch % 5 != 0:
            continue

        model.eval()
        evaled, evi, totalcount = False, 0, 0
        while totalcount < len(val_sampler) and evi < 5:
            try:
                val_loss, val_acc = Accumulator(n_val), Accumulator(n_val)
                val_acc_soft = Accumulator(n_val)
                consensuses = []
                for j, val_sample in enumerate(tqdm(val_sampler)):
                    with torch.no_grad():
                        _, stats_dict, _ = val_loss_fn(model, val_sample)
                    val_loss.append(stats_dict['loss'])
                    val_acc.append(stats_dict['voted_acc'])
                    val_acc_soft.append(stats_dict['probsum_acc'])
                    consensuses.append(stats_dict['agreement'])
                    totalcount += 1
                evaled = True
            except RuntimeError:
                evi += 1
                print('Not evaled')
        assert evaled, 'Not Evaluated!'

        print(
            'train_loss: {0:.4f}, train_acc {1:.2f}%, val_loss: {2:.4f}, val_acc {3:.2f}%'
            .format(epoch_loss.mean(),
                    epoch_acc.mean() * 100, val_loss.mean(),
                    val_acc.mean() * 100))

        agreement = np.array(consensuses).mean(0)
        n_heads = agreement.shape[-1]
        agreement -= np.eye(n_heads)
        # write summaries
        writer.add_scalar('loss/val_loss', val_loss.mean(), epoch)
        writer.add_scalar('accuracy/val_acc', val_acc.mean() * 100, epoch)

        writer.add_scalar('accuracy/val_acc_soft',
                          val_acc_soft.mean() * 100, epoch)
        writer.add_scalar('accuracy/_consensus',
                          agreement.sum() / n_heads / (n_heads - 1) * 100,
                          epoch)

        val_acc_ep = val_acc_soft.mean()
        # if val_loss.mean() < best_val_loss:
        if val_acc_ep > best_val_acc:
            best_val_loss = val_loss.mean()
            best_train_loss = epoch_loss.mean()
            best_val_acc = val_acc_ep
            best_train_acc = epoch_acc.mean()
            waiting_for = 0
            is_best = True

            print('Best model so far!')
        else:
            waiting_for += 1
            is_best = False
            if waiting_for >= args['train.patience']:
                mult = args['train.decay_coef']
                print('Decaying lr by the factor of {}'.format(mult))

                # loading the best model so far and optimizing from that point
                checkpointer.restore_model(ckpt='best', model=True)

                for param_group in optimizer.param_groups:
                    param_group['lr'] = param_group['lr'] / mult
                    lr_ = param_group['lr']
                waiting_for = 0
                if args['train.learning_rate'] / lr_ >= mult**2 - 0.1:
                    print('Out of patience')
                    break

        # saving checkpoints
        if epoch % args['train.ckpt_freq'] == 0 or is_best:
            checkpointer.save_checkpoint(epoch,
                                         best_val_acc,
                                         best_val_loss,
                                         waiting_for,
                                         is_best,
                                         optimizer=optimizer)

    writer.close()
    print(
        '\n Done with train_loss: {0:.4f}, train_acc {1:.2f}%, val_loss: {2:.4f}, val_acc {3:.2f}%'
        .format(best_train_loss, best_train_acc * 100, best_val_loss,
                best_val_acc * 100))
Example #10
class FAADSPlusFormat(object):
    # Based on table from memo m09-19 pages 14-15, which defines the
    # agency submission record format
    FIELDS = [
        ('CFDA Program Number', 'cfda', 7),
        ('State Application Identifier (SAI Number)', 'sai', 20),
        ('Recipient Name', 'recipient_name', 45),
        ('Recipient City Code', 'recipient_city_code', 5),
        ('Recipient City Name', 'recipient_city_name', 21),
        ('Recipient County Code', 'recipient_county_code', 3),
        ('Recipient County Name', 'recipient_county_name', 21),
        ('Recipient State Code', 'recipient_state_code', 2),
        ('Recipient Zip Code', 'recipient_zip_code', 9),
        ('Type of Recipient', 'recipient_type', 2),
        ('Type of Action', 'action_type', 1),
        ('Recipient Congressional District', 'recipient_cd', 2),
        ('Federal Agency/Organizational Unit Code', 'agency_code', 4),
        ('Federal Award Identifier Number (FAIN)', 'award_id', 16),
        ('Federal Award Identifier Number (Modification)', 'award_mod', 4),
        ('Federal Funding Sign', 'fed_funding_sign', 1),
        ('Federal Funding Amount', 'fed_funding_amount', 10),
        ('Non-Federal Funding Sign', 'nonfed_funding_sign', 1),
        ('Non-Federal Funding Amount', 'nonfed_funding_amount', 10),
        ('Total Funding Sign', 'funding_sign', 1),
        ('Total Funding Amount', 'funding_amount', 11),
        ('Obligation/Action Date', 'obligation_action_date', 8),
        ('Starting Date', 'obligation_start_date', 8),
        ('Ending Date', 'obligation_end_date', 8),
        ('Type of Assistance Transaction', 'assistance_type', 2),
        ('Record Type', 'record_type', 1),
        ('Correction/Late Indicator', 'correction_indicator', 1),
        ('Fiscal Year and Quarter Correction', 'fyq_correction', 5),
        ('Principal Place of Performance Code', 'ppop_code', 7),
        ('Principal Place of Performance (State)', 'ppop_state', 25),
        ('Principal Place of Performance (County or City)',
         'ppop_county_or_city', 25),
        ('Principal Place of Performance Zip Code', 'ppop_zip_code', 9),
        ('Principal Place of Performance Congressional District',
         'ppop_cd', 2),
        ('CFDA Program Title', 'cfda_title', 74),
        ('Federal Agency Name', 'agency_name', 72),
        ('State Name', 'state_name', 25),
        ('Project Description', 'project_description', 149),
        ('DUNS Number', 'duns', 9),
        ('DUNS Number PLUS 4', 'duns_plus_4', 4),
        ('Dun & Bradstreet Confidence Code', 'duns_conf_code', 2),
        ('Program Source/Treasury Account Symbol: Agency Code',
         'program_source_agency_code', 2),
        ('Program Source/Treasury Account Symbol: Account Code',
         'program_source_account_code', 4),
        ('Program Source/Treasury Account Symbol: Account Code (OPTIONAL)',
         'program_source_account_code_opt', 3),
        ('Recipient Address Line 1', 'recipient_address1', 35),
        ('Recipient Address Line 2', 'recipient_address2', 35),
        ('Recipient Address Line 3', 'recipient_address3', 35),
        ('Face Value of Direct Loan/Loan Guarantee', 'loan_face_value', 16),
        ('Original Subsidy Cost of the Direct Loan/Loan Guarantee',
         'orig_loan_subsidy_cost', 16),
        ('Business Funds Indicator (BFI)', 'bfi', 3),
        ('Recipient Country Code', 'recipient_country_code', 3),
        ('Principal Place of Performance Country Code',
         'ppop_country_code', 3),
        ('Unique Record Identifier', 'uri', 70)
    ]

    offset_accumulator = Accumulator()
    FIELDS_BY_ABBREV = {
        abbrev: (abbrev, offset_accumulator(length), length, desc)
        for (desc, abbrev, length) in FIELDS
    }

    class Record(object):
        def __init__(self, text):
            self.__text = text
            self.__hash = None

        def __getitem__(self, key):
            field = FAADSPlusFormat.FIELDS_BY_ABBREV.get(key)
            if field is None:
                raise KeyError(key)
            (abbrev, offset, length, desc) = field
            return self.__text[offset:offset + length]

        @property
        def id(self):
            uri = self['uri'].strip()
            if len(uri) > 0:
                return uri
            return (self['award_id'] + self['award_mod']).strip()

        @property
        def fed_funding_value(self):
            text = self['fed_funding_sign'] + self['fed_funding_amount']
            return int(text.strip())

        @property
        def nonfed_funding_value(self):
            text = self['nonfed_funding_sign'] + self['nonfed_funding_amount']
            return int(text.strip())

        @property
        def total_funding_value(self):
            # the FIELDS table names these 'funding_sign' / 'funding_amount'
            text = self['funding_sign'] + self['funding_amount']
            return int(text.strip())

        @property
        def hash(self):
            if self.__hash:
                return self.__hash
            else:
                hasher = hashlib.md5()
                for field in FAADSPlusFormat.FIELDS_BY_ABBREV:
                    hasher.update(self[field].encode('utf-8'))  # md5 needs bytes in Python 3
                self.__hash = hasher.hexdigest()
                return self.__hash

        def as_dict(self):
            return dict(
                ((k, self[k]) for k in FAADSPlusFormat.FIELDS_BY_ABBREV))

    @staticmethod
    def slurp(path):
        with open(path) as fil:  # Python 3: open() replaces the old file() builtin
            return [FAADSPlusFormat.Record(ln) for ln in fil]
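Under the pre-increment reading of Accumulator sketched after Example #1, the offsets in FIELDS_BY_ABBREV line up with the fixed-width record layout: 'cfda' (length 7) starts at offset 0, so 'sai' starts at 7, and so on. A quick hypothetical check of that assumption:

abbrev, offset, length, desc = FAADSPlusFormat.FIELDS_BY_ABBREV['sai']
assert (offset, length) == (7, 20)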
Example #11
def train():

    # start evaluation process
    popen_args = dict(shell=True, universal_newlines=True,
                      encoding='utf-8')  # , stdout=PIPE, stderr=STDOUT, )
    command_valid = 'python main.py -mode=eval ' + ' '.join(
        ['-log_root=' + args.log_root] + sys.argv[1:])
    valid = subprocess.Popen(command_valid, **popen_args)
    print('EVAL: started validation from train process using command:',
          command_valid)
    # eval may or may not be on gpu
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # build graph, dataloader
    cleanloader, dirtyloader, _ = get_loader(join(home, 'datasets'),
                                             batchsize=args.batch_size,
                                             poison=args.poison,
                                             svhn=args.svhn,
                                             fracdirty=args.fracdirty,
                                             cifar100=args.cifar100,
                                             noaugment=args.noaugment,
                                             nogan=args.nogan,
                                             cinic=args.cinic,
                                             tanti=args.tanti)
    dirtyloader = utils.itercycle(dirtyloader)
    # print('Validation check: returncode is '+str(valid.returncode))
    model = resnet_model.ResNet(args, args.mode)
    # print('Validation check: returncode is '+str(valid.returncode))

    # initialize session
    print('===================> TRAIN: STARTING SESSION at ' + timenow())
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            gpu_options=tf.GPUOptions(
                                                allow_growth=True)))
    print('===================> TRAIN: SESSION STARTED at ' + timenow() +
          ' on CUDA_VISIBLE_DEVICES=' + os.environ['CUDA_VISIBLE_DEVICES'])

    # load checkpoint
    utils.download_pretrained(
        log_dir, pretrain_dir=args.pretrain_dir)  # download pretrained model
    ckpt_file = join(log_dir, 'model.ckpt')
    ckpt_state = tf.train.get_checkpoint_state(log_dir)
    var_list = list(
        set(tf.global_variables()) - set(tf.global_variables('accum')) -
        set(tf.global_variables('projvec')))
    saver = tf.train.Saver(var_list=var_list, max_to_keep=1)
    sess.run(tf.global_variables_initializer())
    if not (ckpt_state and ckpt_state.model_checkpoint_path):
        print('TRAIN: No pretrained model. Initialized from random')
    else:
        print('TRAIN: Loading checkpoint %s' %
              ckpt_state.model_checkpoint_path)

    print('TRAIN: Start')
    scheduler = Scheduler(args)
    for epoch in range(args.epoch_end):  # loop over epochs
        accumulator = Accumulator()

        if args.poison:

            # loop over batches
            for batchid, (cleanimages, cleantarget) in enumerate(cleanloader):

                # pull anti-training samples
                dirtyimages, dirtytarget = next(dirtyloader)

                # convert from torch format to numpy onehot, batch them, and apply softmax hack
                cleanimages, cleantarget, dirtyimages, dirtytarget, batchimages, batchtarget, dirtyOne, dirtyNeg = \
                  utils.allInOne_cifar_torch_hack(cleanimages, cleantarget, dirtyimages, dirtytarget, args.nodirty, args.num_classes, args.nogan)

                # from matplotlib.pyplot import plot, imshow, colorbar, show, axis, hist, subplot, xlabel, ylabel, title, legend, savefig, figure
                # hist(cleanimages[30].ravel(), 25); show()
                # hist(dirtyimages[30].ravel(), 25); show()
                # imshow(utils.imagesc(cleanimages[30])); show()
                # imshow(utils.imagesc(dirtyimages[30])); show()

                # run the graph
                _, global_step, loss, predictions, acc, xent, xentPerExample, weight_norm = sess.run(
                    [
                        model.train_op, model.global_step, model.loss,
                        model.predictions, model.precision, model.xent,
                        model.xentPerExample, model.weight_norm
                    ],
                    feed_dict={
                        model.lrn_rate: scheduler._lrn_rate,
                        model._images: batchimages,
                        model.labels: batchtarget,
                        model.dirtyOne: dirtyOne,
                        model.dirtyNeg: dirtyNeg
                    })

                metrics = {}
                metrics['clean/xent'], metrics['dirty/xent'], metrics['clean/acc'], metrics['dirty/acc'] = \
                  accumulator.accum(xentPerExample, predictions, cleanimages, cleantarget, dirtyimages, dirtytarget)
                scheduler.after_run(global_step, len(cleanloader))

                # record metrics and save ckpt so evaluator can be up to date
                if global_step % 250 == 0:
                    saver.save(sess, ckpt_file)
                    metrics['lr'], metrics['train/loss'], metrics['train/acc'], metrics['train/xent'] = \
                      scheduler._lrn_rate, loss, acc, xent
                    metrics['clean_minus_dirty'] = metrics[
                        'clean/acc'] - metrics['dirty/acc']
                    if 'timeold' in locals():
                        metrics['time_per_step'] = (time() - timeold) / 250
                    timeold = time()
                    experiment.log_metrics(metrics, step=global_step)
                    print(
                        'TRAIN: loss: %.3f, acc: %.3f, global_step: %d, epoch: %d, time: %s'
                        % (loss, acc, global_step, epoch, timenow()))

            # log clean and dirty accuracy over entire batch
            metrics = {}
            metrics['clean/acc_full'], metrics['dirty/acc_full'], metrics['clean_minus_dirty_full'], metrics['clean/xent_full'], metrics['dirty/xent_full'] = \
              accumulator.flush()
            experiment.log_metrics(metrics, step=global_step)
            experiment.log_metric('weight_norm', weight_norm)
            print('TRAIN: epoch', epoch, 'finished. cleanacc',
                  metrics['clean/acc_full'], 'dirtyacc',
                  metrics['dirty/acc_full'])

        else:  # use hessian

            # loop over batches
            for batchid, (cleanimages, cleantarget) in enumerate(cleanloader):

                # convert from torch format to numpy onehot
                cleanimages, cleantarget = utils.cifar_torch_to_numpy(
                    cleanimages, cleantarget, args.num_classes)

                # run the graph
                gradsSpecCorr, valtotEager, bzEager, valEager, _, _, global_step, loss, predictions, acc, xent, grad_norm, valEager, projvec_corr, weight_norm = \
                  sess.run([model.gradsSpecCorr, model.valtotEager, model.bzEager, model.valEager, model.train_op, model.projvec_op, model.global_step,
                    model.loss, model.predictions, model.precision, model.xent, model.grad_norm, model.valEager, model.projvec_corr, model.weight_norm],
                    feed_dict={model.lrn_rate: scheduler._lrn_rate,
                               model._images: cleanimages,
                               model.labels: cleantarget,
                               model.speccoef: scheduler.speccoef,
                               model.projvec_beta: args.projvec_beta})

                # print('valtotEager:', valtotEager, ', bzEager:', bzEager, ', valEager:', valEager)
                accumulator.accum(predictions, cleanimages, cleantarget)
                scheduler.after_run(global_step, len(cleanloader))

                # record metrics and save ckpt so evaluator can be up to date
                if global_step % 250 == 0:
                    saver.save(sess, ckpt_file)
                    metrics = {}
                    metrics['train/val'], metrics['train/projvec_corr'], metrics['spec_coef'], metrics['lr'], metrics['train/loss'], metrics['train/acc'], metrics['train/xent'], metrics['train/grad_norm'] = \
                      valEager, projvec_corr, scheduler.speccoef, scheduler._lrn_rate, loss, acc, xent, grad_norm
                    if gradsSpecCorr:
                        metrics['gradsSpecCorrMean'] = sum(
                            gradsSpecCorr) / float(len(gradsSpecCorr))
                    if 'timeold' in locals():
                        # metrics are logged every 250 steps (see the check above)
                        metrics['time_per_step'] = (time() - timeold) / 250
                    timeold = time()
                    experiment.log_metrics(metrics, step=global_step)
                    experiment.log_metric('weight_norm', weight_norm)

                    # plot example train image
                    # plt.imshow(cleanimages[0])
                    # plt.title(cleantarget[0])
                    # experiment.log_figure()

                    # log progress
                    print(
                        'TRAIN: loss: %.3f\tacc: %.3f\tval: %.3f\tcorr: %.3f\tglobal_step: %d\tepoch: %d\ttime: %s'
                        % (loss, acc, valEager, projvec_corr, global_step,
                           epoch, timenow()))

            # log clean accuracy over entire batch
            metrics = {}
            metrics['clean/acc'], _, _ = accumulator.flush()
            experiment.log_metrics(metrics, step=global_step)
            print('TRAIN: epoch', epoch, 'finished. clean/acc',
                  metrics['clean/acc'])

        # log ckpt to comet every 20 epochs
        if epoch % 20 == 0:
            if args.upload:
                experiment.log_asset_folder(log_dir)

        # restart evaluation process if it somehow died
        # if valid.returncode != None:
        #   valid.kill(); sleep(1)
        #   valid = subprocess.Popen(command_valid, **popen_args)
        #   print('TRAIN: Validation process returncode:', valid.returncode)
        #   print('===> Restarted validation process, new PID', valid.pid)

    # upload logs to dropbox
    if args.upload:
        comet.log_asset_folder(log_dir)
        os.system('dbx pload ' + log_dir + ' ' +
                  join('ckpt/poisoncifar', projname) + '/')
Example #12
def main():
    trainset, valset, train_sampler, val_sampler = get_trainval_samplers(args)
    s2d = partial(sample_to_device, device=device)

    train_loss_fn = get_train_loss_fn(args)
    val_loss_fn = get_val_loss_fn(args)

    # Defining the model and Restoring the last checkpoint
    model = get_model(len(trainset.all_classes), args, ensemble=False)
    optimizer = get_optimizer(model, args)

    # Restoring the last checkpoint
    checkpointer = CheckPointer('singles', args, model, optimizer=optimizer)

    if os.path.isfile(checkpointer.last_ckpt) and args['train.resume']:
        start_epoch, best_val_loss, best_val_acc, waiting_for =\
            checkpointer.restore_model(ckpt='last')
    else:
        print('No checkpoint restoration')
        best_val_loss = float('inf')
        best_val_acc = waiting_for = start_epoch = 0

    # defining the summary writer
    writer = SummaryWriter(checkpointer.model_path)

    n_train, n_val = len(train_sampler), len(val_sampler)
    epoch_loss = Accumulator(n_train)
    epoch_acc = Accumulator(n_train)
    for epoch in range(start_epoch, args['train.epochs']):
        print('\n !!!!!! Starting epoch %d !!!!!!' % epoch)

        model.train()
        for i, sample in enumerate(tqdm(train_sampler)):
            optimizer.zero_grad()

            sample = s2d(sample)
            logits = model.forward(sample['images'])
            batch_loss, stats_dict, _ = train_loss_fn(logits, sample['labels'])
            epoch_loss.append(stats_dict['loss'])
            epoch_acc.append(stats_dict['acc'])

            batch_loss.backward()
            optimizer.step()

            t = epoch * n_train + i
            if t % 100 == 0:
                writer.add_scalar('loss/train_loss', epoch_loss.mean(last=100),
                                  t)
                writer.add_scalar('accuracy/train_acc',
                                  epoch_acc.mean(last=100), t)
                writer.add_scalar('learning_rate',
                                  optimizer.param_groups[0]['lr'], t)
            # write images
            if args['train.image_summary'] and t % 1000 == 0:
                grid = make_grid(unnormalize(sample['images'][:9]))
                writer.add_image('images', grid, t)

        # since cub is smaller, evaluate every 5 epochs
        if args['data.dataset'] == 'cub' and epoch % 5 != 0:
            continue

        model.eval()
        val_loss, val_acc = Accumulator(n_val), Accumulator(n_val)
        evaled, evi, totalcount = False, 0, 0
        while totalcount < len(val_sampler) and evi < 5:
            try:
                for j, val_sample in enumerate(tqdm(val_sampler)):
                    with torch.no_grad():
                        _, stats_dict, _ = val_loss_fn(model, val_sample)
                    val_loss.append(stats_dict['loss'])
                    val_acc.append(stats_dict['acc'])
                    totalcount += 1
                evaled = True
            except RuntimeError:
                evi += 1
                print('Not evaled')
        assert evaled, 'Not Evaluated!'

        print(
            'train_loss: {0:.4f}, train_acc {1:.2f}%, val_loss: {2:.4f}, val_acc {3:.2f}%'
            .format(epoch_loss.mean(),
                    epoch_acc.mean() * 100, val_loss.mean(),
                    val_acc.mean() * 100))

        # write summaries
        writer.add_scalar('loss/val_loss', val_loss.mean(), epoch)
        writer.add_scalar('accuracy/val_acc', val_acc.mean() * 100, epoch)

        if val_loss.mean() < best_val_loss:
            best_val_loss = val_loss.mean()
            best_train_loss = epoch_loss.mean()
            best_val_acc = val_acc.mean()
            best_train_acc = epoch_acc.mean()
            waiting_for = 0
            is_best = True

            print('Best model so far!')
        else:
            waiting_for += 1
            is_best = False
            if waiting_for >= args['train.patience']:
                mult = args['train.decay_coef']
                print('Decaying lr by the factor of {}'.format(mult))

                # loading the best model so far and optimizing from that point
                checkpointer.restore_model(ckpt='best', model=True)

                for param_group in optimizer.param_groups:
                    param_group['lr'] = param_group['lr'] / mult
                    lr_ = param_group['lr']
                waiting_for = 0
                if args['train.learning_rate'] / lr_ >= mult**2 - 0.1:
                    print('Out of patience')
                    break

        # saving checkpoints
        if epoch % args['train.ckpt_freq'] == 0 or is_best:
            checkpointer.save_checkpoint(epoch,
                                         best_val_acc,
                                         best_val_loss,
                                         waiting_for,
                                         is_best,
                                         optimizer=optimizer)

    writer.close()
    print(
        '\n Done with train_loss: {0:.4f}, train_acc {1:.2f}%, val_loss: {2:.4f}, val_acc {3:.2f}%'
        .format(best_train_loss, best_train_acc * 100, best_val_loss,
                best_val_acc * 100))