Example No. 1
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--resume',
        default='log/models/last.checkpoint',
        type=str,
        metavar='PATH',
        help='path to latest checkpoint (default: log/models/last.checkpoint)')
    parser.add_argument('-d', type=int, default=0, help='Which gpu to use')
    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.backends.cudnn.benchmark = True

    net = create_network()
    net.to(device)

    ds_val = create_test_dataset(512)

    attack_method = config.create_evaluation_attack_method(device)

    if os.path.isfile(args.resume):
        load_checkpoint(args.resume, net)

    print('Evaluating')
    clean_acc, adv_acc = eval_one_epoch(net, ds_val, device, attack_method)
    print('clean acc -- {}     adv acc -- {}'.format(clean_acc, adv_acc))
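Note: load_checkpoint(args.resume, net) is a project helper that is not shown in these examples. A minimal sketch of what such a helper typically does, assuming the checkpoint was written with torch.save and stores the weights under a 'state_dict' key (both of those are assumptions):

import torch

def load_checkpoint(path, net):
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine
    checkpoint = torch.load(path, map_location='cpu')
    # fall back to treating the file as a bare state_dict if there is no wrapper dict
    state_dict = checkpoint.get('state_dict', checkpoint)
    net.load_state_dict(state_dict)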
Example No. 2
def process_single_epoch():
    print('**************')
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', type=int, default=0, help='Which gpu to use')
    args = parser.parse_args()

    DEVICE = torch.device('cuda:{}'.format(args.d))
    torch.backends.cudnn.benchmark = True

    net = create_network()
    net.to(DEVICE)

    nat_val = load_test_dataset(10000, natural=True)
    adv_val = load_test_dataset(10000, natural=False)

    AttackMethod = config.create_evaluation_attack_method(DEVICE)

    filename = '../ckpts/6leaf-epoch29.checkpoint'
    print(filename)
    if os.path.isfile(filename):
        load_checkpoint(filename, net)

    print('Evaluating Natural Samples')
    clean_acc, adv_acc = my_eval_one_epoch(net, nat_val, DEVICE, AttackMethod)
    print('clean acc -- {}     adv acc -- {}'.format(clean_acc, adv_acc))

    print('Evaluating Adversarial Samples')
    clean_acc, adv_acc = my_eval_one_epoch(net, adv_val, DEVICE, AttackMethod)
    print('clean acc -- {}     adv acc -- {}'.format(clean_acc, adv_acc))
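Note: eval_one_epoch and my_eval_one_epoch are also project helpers. A rough sketch of the pattern they implement (clean accuracy plus accuracy under the configured attack), assuming attack_method(net, x, y) returns adversarially perturbed inputs; that calling convention is an assumption:

import torch

def eval_one_epoch(net, loader, device, attack_method):
    net.eval()
    clean_correct, adv_correct, total = 0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            clean_correct += (net(x).argmax(dim=1) == y).sum().item()
        # the attack needs input gradients, so it runs outside torch.no_grad()
        x_adv = attack_method(net, x, y)
        with torch.no_grad():
            adv_correct += (net(x_adv).argmax(dim=1) == y).sum().item()
        total += y.size(0)
    return clean_correct / total, adv_correct / total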
Example No. 3
def train_multi_wrapper(ctx, symbol, snapshot_prefix, init_weight_file,
                        im_folder, multi_label_file, class_num, rgb_mean,
                        epoch_size, max_epoch, input_size, batch_size, lr, wd,
                        momentum, lr_decay, workspace):

    train_symbol = symbol.create_train(class_num, workspace)
    data_iter = MultiLabelIter(image_root=im_folder,
                               label_file=multi_label_file,
                               num_class=class_num,
                               rgb_mean=rgb_mean,
                               epoch_size=epoch_size,
                               im_size=input_size,
                               shuffle=True,
                               random_flip=True,
                               batch_size=batch_size)

    if not os.path.exists(init_weight_file):
        logging.error("no file found for %s", init_weight_file)
        return

    arg_dict, aux_dict, _ = misc.load_checkpoint(init_weight_file)

    initializer = mx.initializer.Normal()
    initializer.set_verbosity(True)

    mod = mx.mod.Module(train_symbol,
                        context=ctx,
                        data_names=["data"],
                        label_names=["label"])

    mod.bind(data_shapes=data_iter.provide_data,
             label_shapes=data_iter.provide_label)
    mod.init_params(initializer=initializer,
                    arg_params=arg_dict,
                    aux_params=aux_dict,
                    allow_missing=True)

    opt_params = {
        "learning_rate": lr,
        "wd": wd,
        'momentum': momentum,
        'lr_scheduler': mx.lr_scheduler.FactorScheduler(step=lr_decay,
                                                        factor=0.1),
        'rescale_grad': 1.0 / batch_size
    }

    eval_metrics = [metrics.MultiLogisticLoss()]

    mod.fit(data_iter,
            optimizer="sgd",
            optimizer_params=opt_params,
            num_epoch=max_epoch + 1,
            epoch_end_callback=callbacks.module_checkpoint(snapshot_prefix, 1),
            batch_end_callback=callbacks.Speedometer(batch_size, frequent=10),
            eval_metric=eval_metrics,
            begin_epoch=1)
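Note: misc.load_checkpoint is project code. A plausible sketch, assuming the weight file is a standard MXNet .params file whose keys carry the usual arg:/aux: prefixes; the file-naming scheme below is also an assumption:

import mxnet as mx

def load_checkpoint(prefix, epoch=None):
    # accept either a full .params path or a (prefix, epoch) pair, matching both call sites
    filename = prefix if epoch is None else '%s-%04d.params' % (prefix, epoch)
    save_dict = mx.nd.load(filename)
    arg_dict, aux_dict = {}, {}
    for key, value in save_dict.items():
        # MXNet checkpoints store parameters under "arg:<name>" / "aux:<name>" keys
        kind, name = key.split(':', 1)
        if kind == 'arg':
            arg_dict[name] = value
        elif kind == 'aux':
            aux_dict[name] = value
    # the third value is a placeholder; the callers above unpack and ignore it
    return arg_dict, aux_dict, None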
Example No. 4
def main():
    DEVICE = torch.device('cuda:{}'.format(args.d))
    torch.backends.cudnn.benchmark = True

    net = create_network()
    net.to(DEVICE)
    criterion = config.create_loss_function().to(DEVICE)

    optimizer = config.create_optimizer(net.parameters())
    lr_scheduler = config.create_lr_scheduler(optimizer)

    ds_train = create_train_dataset(args.batch_size)
    ds_val = create_test_dataset(args.batch_size)

    TrainAttack = config.create_attack_method(DEVICE)
    EvalAttack = config.create_evaluation_attack_method(DEVICE)

    now_epoch = 0

    if args.auto_continue:
        args.resume = os.path.join(config.model_dir, 'last.checkpoint')
    if args.resume is not None and os.path.isfile(args.resume):
        now_epoch = load_checkpoint(args.resume, net, optimizer, lr_scheduler)

    while True:
        if now_epoch > config.num_epochs:
            break
        now_epoch = now_epoch + 1

        descrip_str = 'Training epoch:{}/{} -- lr:{}'.format(
            now_epoch, config.num_epochs,
            lr_scheduler.get_lr()[0])
        train_one_epoch(net,
                        ds_train,
                        optimizer,
                        criterion,
                        DEVICE,
                        descrip_str,
                        TrainAttack,
                        adv_coef=args.adv_coef)
        if config.eval_interval > 0 and now_epoch % config.eval_interval == 0:
            eval_one_epoch(net, ds_val, DEVICE, EvalAttack)

        lr_scheduler.step()

        save_checkpoint(now_epoch,
                        net,
                        optimizer,
                        lr_scheduler,
                        file_name=os.path.join(
                            config.model_dir,
                            'epoch-{}.checkpoint'.format(now_epoch)))
Example No. 5
def main():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.backends.cudnn.benchmark = True

    net = create_network()
    net.to(device)
    criterion = config.create_loss_function().to(device)

    optimizer = config.create_optimizer(net.parameters())
    lr_scheduler = config.create_lr_scheduler(optimizer)

    ds_train = create_train_dataset(args.batch_size)
    ds_val = create_test_dataset(args.batch_size)

    train_attack = config.create_attack_method(device)
    eval_attack = config.create_evaluation_attack_method(device)

    now_epoch = 0

    if args.auto_continue:
        args.resume = os.path.join(config.model_dir, 'last.checkpoint')
    if args.resume is not None and os.path.isfile(args.resume):
        now_epoch = load_checkpoint(args.resume, net, optimizer, lr_scheduler)

    for i in range(now_epoch, config.num_epochs):

        descrip_str = 'Training epoch:{}/{} -- lr:{}'.format(i, config.num_epochs,
                                                             lr_scheduler.get_last_lr()[0])
        train_one_epoch(net, ds_train, optimizer, criterion, device,
                        descrip_str, train_attack, adv_coef=args.adv_coef)
        if config.eval_interval > 0 and i % config.eval_interval == 0:
            eval_one_epoch(net, ds_val, device, eval_attack)

        lr_scheduler.step()

    save_checkpoint(i, net, optimizer, lr_scheduler,
                    file_name=os.path.join(config.model_dir, 'epoch-{}.checkpoint'.format(i)))
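Note: Examples 4 and 5 assume a save_checkpoint / load_checkpoint pair that round-trips the epoch counter together with the model, optimizer, and scheduler state. A minimal sketch of that contract; the stored key names are assumptions, not the project's actual format:

import torch

def save_checkpoint(epoch, net, optimizer, lr_scheduler, file_name):
    torch.save({
        'epoch': epoch,
        'state_dict': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
    }, file_name)

def load_checkpoint(file_name, net, optimizer=None, lr_scheduler=None):
    checkpoint = torch.load(file_name, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if lr_scheduler is not None:
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    # the returned epoch lets the caller resume the training loop
    return checkpoint['epoch']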
Example No. 6
LayerOneTrainer = FastGradientLayerOneTrainer(Hamiltonian_func,
                                              layer_one_optimizer,
                                              config.inner_iters, config.sigma,
                                              config.eps)

ds_train = create_train_dataset(args.batch_size)
ds_val = create_test_dataset(args.batch_size)

EvalAttack = config.create_evaluation_attack_method(DEVICE)

now_epoch = 0

if args.auto_continue:
    args.resume = os.path.join(config.model_dir, 'last.checkpoint')
if args.resume is not None and os.path.isfile(args.resume):
    now_epoch = load_checkpoint(args.resume, net, optimizer, lr_scheduler)

now_train_time = 0
while True:
    if now_epoch > config.num_epochs:
        break
    now_epoch = now_epoch + 1

    descrip_str = 'Training epoch:{}/{} -- lr:{}'.format(
        now_epoch, config.num_epochs,
        lr_scheduler.get_lr()[0])
    s_time = time.time()
    acc, yofoacc = train_one_epoch(net, ds_train, optimizer, criterion,
                                   LayerOneTrainer, config.K, DEVICE,
                                   descrip_str)
    now_train_time = now_train_time + time.time() - s_time
Example No. 7
def train_sec_wrapper(ctx, epoch, lr, mem_mirror, model_prefix, symbol, class_num, workspace, init_weight_file,
                      im_root, rgb_mean, im_size, label_shrink_scale, cue_file_path,
                      epoch_size, max_epoch, batch_size, wd, momentum, lr_decay, q_fg, q_bg, SC_only=False):

    if mem_mirror:
        os.environ["MXNET_BACKWARD_DO_MIRROR"] = "1"
    arg_dict = {}
    aux_dict = {}
    outputsize = int(im_size*label_shrink_scale)
    seg_net = symbol.create_training(class_num=class_num, outputsize=outputsize, workspace=workspace, SC_only=SC_only)
    if epoch == 0:
        if not os.path.exists(init_weight_file):
            logging.warning("No model file found at %s. Start from scratch!", init_weight_file)
        else:
            arg_dict, aux_dict, _ = misc.load_checkpoint(init_weight_file)
    else:
        arg_dict, aux_dict, _ = misc.load_checkpoint(model_prefix, epoch)
    #init weights for expand loss
    if not SC_only:
        arg_dict["fg_w"] = mx.nd.array(np.array([q_fg ** i for i in range(outputsize * outputsize - 1, -1, -1)])[None, None, :])
        arg_dict["bg_w"] = mx.nd.array(np.array([q_bg ** i for i in range(outputsize * outputsize - 1, -1, -1)])[None, :])

    data_iter = SECTrainingIter(
        im_root=im_root,
        cue_file_path=cue_file_path,
        class_num=class_num,
        rgb_mean=rgb_mean,
        im_size=im_size,
        shuffle=True,
        label_shrink_scale=label_shrink_scale,
        random_flip=True,
        data_queue_size=8,
        epoch_size=epoch_size,
        batch_size=batch_size,
        round_batch=True,
        SC_only=SC_only
    )

    initializer = mx.initializer.Normal()
    initializer.set_verbosity(True)

    if SC_only:
        mod = mx.mod.Module(seg_net, context=ctx, data_names=["data", "small_ims"], label_names=["cues"])
    else:
        mod = mx.mod.Module(seg_net, context=ctx, data_names=["data", "small_ims"], label_names=["labels", "cues"])

    mod.bind(data_shapes=data_iter.provide_data,
            label_shapes=data_iter.provide_label)
    mod.init_params(initializer=initializer, arg_params=arg_dict, aux_params=aux_dict, allow_missing=(epoch == 0))

    opt_params = {"learning_rate":lr,
                "wd": wd,
                'momentum': momentum,
                'lr_scheduler': mx.lr_scheduler.FactorScheduler(step=lr_decay, factor=0.1),
                'rescale_grad': 1.0/len(ctx)}

    if SC_only:
        eval_metrics = [metrics.SEC_seed_loss(), metrics.SEC_constrain_loss()]
    else:
        eval_metrics = [metrics.SEC_seed_loss(), metrics.SEC_constrain_loss(), metrics.SEC_expand_loss()]
    mod.fit(data_iter,
            optimizer="sgd",
            optimizer_params=opt_params,
            num_epoch=max_epoch+1,
            epoch_end_callback=callbacks.module_checkpoint(model_prefix),
            batch_end_callback=callbacks.Speedometer(batch_size, frequent=10),
            eval_metric=eval_metrics,
            begin_epoch=epoch+1)
Example No. 8
        elif args.model_dir.split('/')[-1] == 'convolutional':
            from model.convolutional_VAE import VariationalAutoencoder

            reshape = False

    elif args.dataloader == 'orl_face':
        import data_loaders.orl_data_loader as data_loader

        dataloaders = data_loader.fetch_dataloader(types=['test'], data_dir=data_dir, params=params)

        # fetch model
        reshape = False
        from model.face_conv_VAE import VariationalAutoencoder

    test_dl = dataloaders['test']

    model = VariationalAutoencoder(params).cuda() if params.cuda else VariationalAutoencoder(params)
    # Reload weights from the saved file
    misc.load_checkpoint(os.path.join(log_dir, args.restore_file + log + '.pth.tar'), model)

    # Generate images
    if args.generate == 'gif':
        logging.info("\nGenerating gif\n")
        generate_images(model, model_dir, test_dl, params)
    elif args.generate == 'test':
        logging.info("\nReconstruct images form test set\n")
        reconstruct_images_form_test_set(model, test_dl, args.dataloader, params, log_dir, reshape)
    else:
        raise NotImplementedError("Generation not implemented, use gif or test.")

Example No. 9
def train(model,
          dataloader,
          dl_type,
          optimizer,
          loss_fn,
          params,
          model_dir,
          use_swa,
          restore_file=None):
    """
    Train the model on `num_steps` batches

    Parameters
    ----------
    model
    dataloader
    dl_type
    optimizer
    loss_fn
    params
    model_dir
    reshape
    restore_file

    Returns
    -------

    """

    random_vector_for_generation = torch.randn(
        torch.Size([params.num_examples_to_generate,
                    params.latent_dim])).cuda()

    logging.info("\nTraining started.\n")
    # Add a tensorboardX SummaryWriter to log training; logs are saved under model_dir
    run = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    log_dir = create_log_dir(model_dir, run)
    img_dir = os.path.join(log_dir, 'generated_images')
    create_dir(img_dir)
    with SummaryWriter(log_dir) as writer:
        # set model to training mode
        model.train()

        # reload weights from restore_file if specified
        if restore_file is not None:
            restore_path = os.path.join(model_dir,
                                        restore_file + '.pth.tar')
            logging.info("Restoring parameters from {}".format(restore_path))
            load_checkpoint(restore_path, model, optimizer)

        # number of iterations
        iterations = 0
        # Use tqdm progress bar for number of epochs
        for epoch in tqdm(range(params.num_epochs),
                          desc="Epochs: ",
                          leave=True):
            # Track the progress of the training batches
            training_progressor = trange(len(dataloader), desc="Loss")
            # Create the batch iterator once per epoch; calling iter(dataloader) on
            # every step would restart (and reshuffle) the loader each time
            batch_iter = iter(dataloader)
            for i in training_progressor:
                iterations += 1
                # Fetch next batch of training samples
                if dl_type == 'orl_face':
                    train_batch = next(batch_iter)
                else:
                    train_batch, _ = next(batch_iter)

                true_samples = torch.randn(params.batch_size,
                                           params.latent_dim)

                # move to GPU if available
                if params.cuda:
                    train_batch = train_batch.cuda()
                    true_samples = true_samples.cuda()

                X_reconstructed, z = model(train_batch)
                losses = loss_fn(train_batch, X_reconstructed, true_samples, z)
                loss = losses['loss']

                # clear previous gradients, compute gradients of all variables wrt loss
                optimizer.zero_grad()
                loss.backward()

                # performs updates using calculated gradients
                optimizer.step()

                # Evaluate model parameters only once in a while
                if (i + 1) % params.save_summary_steps == 0:
                    # Log values and gradients of the model parameters (histogram summary)
                    for tag, value in model.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram(tag,
                                             value.cpu().data.numpy(),
                                             iterations)
                        writer.add_histogram(tag + '/grad',
                                             value.grad.cpu().data.numpy(),
                                             iterations)

                # Compute the loss for each iteration
                summary_batch = losses
                # log loss and/or other metrics to the writer
                for tag, value in summary_batch.items():
                    writer.add_scalar(tag, value.item(), iterations)

                # update the average loss
                training_progressor.set_description("VAE (Loss=%g)" %
                                                    round(loss.item(), 4))

            # generate images for gif
            if epoch % 1 == 0:
                generate_and_save_images(model, epoch,
                                         random_vector_for_generation, img_dir)

            # Save weights
            if epoch % 10 == 0:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optim_dict': optimizer.state_dict()
                    },
                    is_best=True,
                    checkpoint=log_dir,
                    datetime=run)

        logging.info("\n\nTraining Completed.\n\n")
        if use_swa:
            optimizer.swap_swa_sgd()

        logging.info(
            "Creating gif of images generated with gaussian latent vectors.\n")
        generate_gif(img_dir, writer)
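Note: save_checkpoint here takes a plain state dict plus an is_best flag, a checkpoint directory, and a run timestamp. A hedged sketch of such a helper; the output file names are assumptions:

import os
import shutil
import torch

def save_checkpoint(state, is_best, checkpoint, datetime):
    # write the latest state, then keep a separate copy when it is the best so far
    filepath = os.path.join(checkpoint, 'last_' + datetime + '.pth.tar')
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'best_' + datetime + '.pth.tar'))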
Example No. 10
import argparse
import os

import torch

parser = argparse.ArgumentParser()
parser.add_argument(
    '--resume',
    default='../ckpts/full-epoch32.checkpoint',
    type=str,
    metavar='PATH',
    help='path to latest checkpoint (default: ../ckpts/full-epoch32.checkpoint)'
)
parser.add_argument('-d', type=int, default=0, help='Which gpu to use')
args = parser.parse_args()

DEVICE = torch.device('cuda:{}'.format(args.d))
torch.backends.cudnn.benchmark = True

net = create_network()
net.to(DEVICE)

ds_val = create_test_dataset(512)

AttackMethod = config.create_evaluation_attack_method(DEVICE)

if os.path.isfile(args.resume):
    load_checkpoint(args.resume, net)

print('Evaluating')
clean_acc, adv_acc = eval_one_epoch(net, ds_val, DEVICE, AttackMethod)
print('clean acc -- {}     adv acc -- {}'.format(clean_acc, adv_acc))
Example No. 11
def test_seg_wrapper(epoch, ctx, output_folder, model_name, save_mask, save_scoremap, net_symbol, class_num, workspace,
                     snapshot_folder, max_dim, im_root, mask_root, flist_path, rgb_mean, scale_list,
                     class_names, use_crf=False, crf_params=None):

    os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"]="0"

    crf_obj = None
    if use_crf:
        assert crf_params is not None
        crf_obj = CRF(pos_xy_std=crf_params["pos_xy_std"], pos_w=crf_params["pos_w"], bi_xy_std=crf_params["bi_xy_std"],
                      bi_rgb_std=crf_params["bi_rgb_std"], bi_w=crf_params["bi_w"], scale_factor=1.0)


    epoch_str = str(epoch)
    misc.my_mkdir(output_folder)
    misc.my_mkdir(os.path.join(output_folder, model_name + "_epoch" + epoch_str))

    if save_mask:
        misc.my_mkdir(os.path.join(output_folder, model_name + "_epoch" + epoch_str, "masks"))
    if save_scoremap:
        misc.my_mkdir(os.path.join(output_folder, model_name + "_epoch" + epoch_str, "scoremaps"))

    cmap = get_cmap()

    model_prefix = os.path.join(snapshot_folder, model_name)
    seg_net = net_symbol.create_infer(class_num, workspace)
    arg_dict, aux_dict, _ = misc.load_checkpoint(model_prefix, epoch)


    mod = mx.mod.Module(seg_net, data_names=["data", "orig_data"], label_names=[], context=ctx)
    mod.bind(data_shapes=[("data", (1, 3, max_dim, max_dim)), ("orig_data", (1, 3, max_dim, max_dim))],
             for_training=False, grad_req="null")
    initializer = mx.init.Normal()
    initializer.set_verbosity(True)
    mod.init_params(initializer=initializer, arg_params=arg_dict, aux_params=aux_dict, allow_missing=True)

    data_producer = InferenceDataProducer(
            im_root=im_root,
            mask_root=mask_root,
            flist_path=flist_path,
            rgb_mean=rgb_mean,
            scale_list=scale_list)

    nbatch = 0
    eval_metrics = [metrics.IOU(class_num, class_names)]
    logging.info("In evaluation...")

    while True:
        data = data_producer.get_data()
        if data is None:
            break
        im_list = data[0]

        label = data[1].squeeze()
        file_name = data[2]
        ori_im = data[3]
        final_scoremaps = mx.nd.zeros((class_num, label.shape[0], label.shape[1]))

        for im in im_list:
            im, orig_size = misc.pad_image(im, 8)
            mod.reshape(data_shapes=[("data", im.shape), ("orig_data", (1, 3, orig_size[0], orig_size[1]))])
            mod.forward(mx.io.DataBatch(data=[mx.nd.array(im), mx.nd.zeros((1, 3, orig_size[0], orig_size[1]))]))

            score = mx.nd.transpose(mod.get_outputs()[0].copyto(mx.cpu()), [0, 2, 3, 1])
            score = mx.nd.reshape(score, (score.shape[1], score.shape[2], score.shape[3]))
            up_score = mx.nd.transpose(mx.image.imresize(score, label.shape[1], label.shape[0], interp=1), [2, 0, 1])

            final_scoremaps += up_score

        final_scoremaps = mx.nd.log(final_scoremaps)
        final_scoremaps = final_scoremaps.asnumpy()
        if use_crf:
            assert crf_params is not None
            final_scoremaps = crf_obj.inference(ori_im.asnumpy(), final_scoremaps)

        pred_label = final_scoremaps.argmax(0)

        for metric in eval_metrics:
            metric.update(label, pred_label)

        if save_mask:
            out_img = np.uint8(pred_label)
            out_img = Image.fromarray(out_img)
            out_img.putpalette(cmap)
            out_img.save(os.path.join(output_folder, model_name + "_epoch" + epoch_str, "masks", file_name+".png"))
        if save_scoremap:
            np.save(os.path.join(output_folder, model_name + "_epoch" + epoch_str, "scoremaps", file_name), final_scoremaps)

        nbatch += 1
        if nbatch % 10 == 0:
            print "processed %dth batch" % nbatch

    logging.info("Epoch [%d]: " % epoch)
    for m in eval_metrics:
        logging.info("[overall] [%s: %.4f]" % (m.get()[0], m.get()[1]))
        if m.get_class_values() is not None:
            scores = "[perclass] ["
            for v in m.get_class_values():
                scores += "%s: %.4f\t" % (v[0], v[1])
            scores += "]"
            logging.info(scores)
Example No. 12
def train_seg_wrapper(ctx, epoch, lr, model_prefix, symbol, class_num, workspace, init_weight_file,
                      im_root, mask_root, flist_path, use_g_labels, rgb_mean, crop_size, scale_range, label_shrink_scale,
                      epoch_size, max_epoch, batch_size, wd, momentum):

    arg_dict = {}
    aux_dict = {}
    if use_g_labels:
        seg_net = symbol.create_training(class_num=class_num, gweight=1.0/batch_size, workspace=workspace)
    else:
        seg_net = symbol.create_training(class_num=class_num, workspace=workspace)
    if epoch == 0:
        if not os.path.exists(init_weight_file):
            logging.warning("No model file found at %s. Start from scratch!", init_weight_file)
        else:
            arg_dict, aux_dict, _ = misc.load_checkpoint(init_weight_file)
            param_types = ["_weight", "_bias", "_gamma", "_beta", "_moving_mean", "_moving_var"]
            #copy params for global branch
            if use_g_labels:
                # materialize the keys first, since new entries are added to arg_dict below
                for arg in list(arg_dict.keys()):
                    for param_type in param_types:
                        if param_type in arg:
                            arg_name = arg[:arg.rfind(param_type)]
                            arg_dict[arg_name + "_g" + param_type] = arg_dict[arg].copy()
                            if arg_name in ["fc6", "fc7"]:
                                arg_dict[arg_name + "_1" + param_type] = arg_dict[arg].copy()
                                arg_dict[arg_name + "_2" + param_type] = arg_dict[arg].copy()
                                arg_dict[arg_name + "_3" + param_type] = arg_dict[arg].copy()
                                arg_dict[arg_name + "_4" + param_type] = arg_dict[arg].copy()
                            break
                for aux in list(aux_dict.keys()):
                    for param_type in param_types:
                        if param_type in aux:
                            aux_name = aux[:aux.rfind(param_type)]
                            aux_dict[aux_name + "_g" + param_type] = aux_dict[aux].copy()
                            break
    else:
        arg_dict, aux_dict, _ = misc.load_checkpoint(model_prefix, epoch)

    data_iter = SegTrainingIter(
        im_root=im_root,
        mask_root=mask_root,
        file_list_path=flist_path,
        provide_g_labels=use_g_labels,
        class_num=class_num,
        rgb_mean=rgb_mean,
        crop_size=crop_size,
        shuffle=True,
        scale_range=scale_range,
        label_shrink_scale=label_shrink_scale,
        random_flip=True,
        data_queue_size=8,
        epoch_size=epoch_size,
        batch_size=batch_size,
        round_batch=True
    )


    initializer = mx.initializer.Normal()
    initializer.set_verbosity(True)

    if use_g_labels:
        mod = mx.mod.Module(seg_net, context=ctx, label_names=["softmax_label", "g_logistic_label"])
    else:
        mod = mx.mod.Module(seg_net, context=ctx, label_names=["softmax_label"])
    mod.bind(data_shapes=data_iter.provide_data,
            label_shapes=data_iter.provide_label)
    mod.init_params(initializer=initializer, arg_params=arg_dict, aux_params=aux_dict, allow_missing=(epoch == 0))

    opt_params = {"learning_rate":lr,
                "wd": wd,
                'momentum': momentum,
                'rescale_grad': 1.0/len(ctx)}

    if use_g_labels:
        eval_metrics = [metrics.Accuracy(), metrics.Loss(), metrics.MultiLogisticLoss(l_index=1, p_index=1)]
    else:
        eval_metrics = [metrics.Accuracy(), metrics.Loss()]
    mod.fit(data_iter,
            optimizer="sgd",
            optimizer_params=opt_params,
            num_epoch=max_epoch,
            epoch_end_callback=callbacks.module_checkpoint(model_prefix),
            batch_end_callback=callbacks.Speedometer(batch_size, frequent=10),
            eval_metric=eval_metrics,
            begin_epoch=epoch+1)
Example No. 13
def main(args):
    # Seed
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    # cudnn benchmarking selects algorithms non-deterministically, so keep it off
    # when deterministic results are required
    torch.backends.cudnn.benchmark = False
    np.random.seed(args.seed)

    if args.featurize_mode:
        msg = "To perform featurization, use evaluation mode"
        assert args.evaluate and args.evaluate_video, msg
        msg = (
            f"Until we fully understand the implications of multi-worker caching, we "
            f"should avoid using multiple workers (requested {args.workers})")
        assert args.workers <= 1, msg

    # create checkpoint dir
    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Overload print statement to log to file
    setup_verbose_logging(Path(args.checkpoint))
    logger_name = "train" if not args.evaluate else "eval"
    plog = logging.getLogger(logger_name)

    opts.print_args(args)
    opts.save_args(args, save_folder=args.checkpoint)

    if not args.debug:
        plt.switch_backend("agg")

    # create model
    plog.info(f"==> creating model '{args.arch}', out_dim={args.num_classes}")
    if args.arch == "InceptionI3d":
        model = models.__dict__[args.arch](
            num_classes=args.num_classes,
            spatiotemporal_squeeze=True,
            final_endpoint="Logits",
            name="inception_i3d",
            in_channels=3,
            dropout_keep_prob=0.5,
            num_in_frames=args.num_in_frames,
            include_embds=args.include_embds,
        )
        if args.save_features:
            msg = "Set --include_embds 1 to save_features"
            assert args.include_embds, msg
    elif args.arch == "Pose2Sign":
        model = models.Pose2Sign(num_classes=args.num_classes, )
    else:
        model = models.__dict__[args.arch](num_classes=args.num_classes, )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # adjust for opts for multi-gpu training. Note that we also apply warmup to the
    # learning rate. Can technically remove this if-statement, but leaving for now
    # to make the change explicit.
    if args.num_gpus > 1:
        num_gpus = torch.cuda.device_count()
        msg = f"Requested {args.num_gpus}, but {num_gpus} were visible"
        assert num_gpus == args.num_gpus, msg
        args.train_batch = args.train_batch * args.num_gpus
        args.test_batch = args.test_batch * args.num_gpus
        device_ids = list(range(args.num_gpus))
        args.lr = args.lr * args.num_gpus
    else:
        device_ids = [0]

    model = torch.nn.DataParallel(model, device_ids=device_ids)
    model = model.to(device)

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    # optionally resume from a checkpoint
    tic = time.time()
    title = f"{args.datasetname} - {args.arch}"
    if args.resume:
        if os.path.isfile(args.resume):
            plog.info(f"=> loading checkpoint '{args.resume}'")
            checkpoint = load_checkpoint(args.resume)
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            args.start_epoch = checkpoint["epoch"]
            plog.info(
                f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})"
            )
            logger = Logger(os.path.join(args.checkpoint, "log.txt"),
                            title=title,
                            resume=True)
            del checkpoint
        else:
            plog.info(f"=> no checkpoint found at '{args.resume}'")
            raise ValueError(f"Checkpoint not found at {args.resume}!")
    else:
        logger = Logger(os.path.join(args.checkpoint, "log.txt"), title=title)
        logger_names = ["Epoch", "LR", "train_loss", "val_loss"]
        for p in range(0, args.nloss - 1):
            logger_names.append("train_loss%d" % p)
            logger_names.append("val_loss%d" % p)
        for p in range(args.nperf):
            logger_names.append("train_perf%d" % p)
            logger_names.append("val_perf%d" % p)

        logger.set_names(logger_names)

    if args.pretrained:
        load_checkpoint_flexible(model, optimizer, args, plog)

    param_count = humanize.intword(sum(p.numel() for p in model.parameters()))
    plog.info(f"    Total params: {param_count}")
    duration = time.strftime("%Hh%Mm%Ss", time.gmtime(time.time() - tic))
    plog.info(f"Loaded parameters for model in {duration}")

    mdl = MultiDataLoader(
        train_datasets=args.datasetname,
        val_datasets=args.datasetname,
    )
    train_loader, val_loader, meanstd = mdl._get_loaders(args)

    train_mean = meanstd[0]
    train_std = meanstd[1]
    val_mean = meanstd[2]
    val_std = meanstd[3]

    save_feature_dir = args.checkpoint
    save_fig_dir = Path(args.checkpoint) / "figs"
    if args.featurize_mode:
        save_feature_dir = Path(
            args.checkpoint) / "filtered" / args.featurize_mask
        save_feature_dir.mkdir(exist_ok=True, parents=True)
        save_fig_dir = Path(args.checkpoint) / "figs" / args.featurize_mask
        save_fig_dir.mkdir(exist_ok=True, parents=True)

    # Define criterion
    criterion = torch.nn.CrossEntropyLoss(reduction="mean")
    criterion = criterion.to(device)

    if args.evaluate or args.evaluate_video:
        plog.info("\nEvaluation only")
        loss, acc = do_epoch(
            "val",
            val_loader,
            model,
            criterion,
            num_classes=args.num_classes,
            debug=args.debug,
            checkpoint=args.checkpoint,
            mean=val_mean,
            std=val_std,
            feature_dim=args.feature_dim,
            save_logits=True,
            save_features=args.save_features,
            num_figs=args.num_figs,
            topk=args.topk,
            save_feature_dir=save_feature_dir,
            save_fig_dir=save_fig_dir,
        )
        if args.featurize_mode:
            plog.info(f"Featurizing without metric evaluation")
            return

        # Summarize/save results
        evaluate.evaluate(args, val_loader.dataset, plog)

        logger_epoch = [0, 0]
        for p in range(len(loss)):
            logger_epoch.append(float(loss[p].avg))
            logger_epoch.append(float(loss[p].avg))
        for p in range(len(acc)):
            logger_epoch.append(float(acc[p].avg))
            logger_epoch.append(float(acc[p].avg))
        # append logger file
        logger.append(logger_epoch)

        return

    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer,
                                  epoch,
                                  lr,
                                  args.schedule,
                                  args.gamma,
                                  num_gpus=args.num_gpus)
        plog.info("\nEpoch: %d | LR: %.8f" % (epoch + 1, lr))

        # train for one epoch
        train_loss, train_perf = do_epoch(
            "train",
            train_loader,
            model,
            criterion,
            epochno=epoch,
            optimizer=optimizer,
            num_classes=args.num_classes,
            debug=args.debug,
            checkpoint=args.checkpoint,
            mean=train_mean,
            std=train_std,
            feature_dim=args.feature_dim,
            save_logits=False,
            save_features=False,
            num_figs=args.num_figs,
            topk=args.topk,
            save_feature_dir=save_feature_dir,
            save_fig_dir=save_fig_dir,
        )

        # evaluate on validation set
        valid_loss, valid_perf = do_epoch(
            "val",
            val_loader,
            model,
            criterion,
            epochno=epoch,
            num_classes=args.num_classes,
            debug=args.debug,
            checkpoint=args.checkpoint,
            mean=val_mean,
            std=val_std,
            feature_dim=args.feature_dim,
            save_logits=False,
            save_features=False,
            num_figs=args.num_figs,
            topk=args.topk,
            save_feature_dir=save_feature_dir,
            save_fig_dir=save_fig_dir,
        )

        logger_epoch = [epoch + 1, lr]
        for p in range(len(train_loss)):
            logger_epoch.append(float(train_loss[p].avg))
            logger_epoch.append(float(valid_loss[p].avg))
        for p in range(len(train_perf)):
            logger_epoch.append(float(train_perf[p].avg))
            logger_epoch.append(float(valid_perf[p].avg))
        # append logger file
        logger.append(logger_epoch)

        # save checkpoint
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": args.arch,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            },
            checkpoint=args.checkpoint,
            snapshot=args.snapshot,
        )

        plt.clf()
        plt.subplot(121)
        logger.plot(["train_loss", "val_loss"])
        plt.subplot(122)
        logger.plot(["train_perf0", "val_perf0"])
        savefig(os.path.join(args.checkpoint, "log.pdf"))

    logger.close()
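Note: adjust_learning_rate is not shown in the example. A typical step-decay implementation consistent with the schedule and gamma arguments used above; the real project may also apply a warmup tied to num_gpus, which is omitted here:

def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma, num_gpus=1):
    # decay the learning rate by gamma whenever the current epoch hits a milestone
    # (num_gpus is accepted only to match the call site; any warmup logic is omitted)
    if epoch in schedule:
        lr = lr * gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    return lr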