def main():
    """
    Main body of script.
    """
    args = parse_args()
    args.seed = init_rand(seed=args.seed)

    _, log_file_exist = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    ctx, batch_size = prepare_mx_context(
        num_gpus=args.num_gpus,
        batch_size=args.batch_size)

    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    ds_metainfo.update(args=args)

    use_teacher = (args.teacher_models is not None) and (args.teacher_models.strip() != "")

    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        net_extra_kwargs=ds_metainfo.train_net_extra_kwargs,
        tune_layers=args.tune_layers,
        classes=args.num_classes,
        in_channels=args.in_channels,
        do_hybridize=(not args.not_hybridize),
        initializer=get_initializer(initializer_name=args.initializer),
        ctx=ctx)
    assert (hasattr(net, "classes"))
    num_classes = net.classes

    # Optional MEAL-style distillation: build a frozen ensemble of pretrained teachers
    # and (unless disabled) an adversarial discriminator with its own loss.
    teacher_net = None
    discrim_net = None
    discrim_loss_func = None
    if use_teacher:
        teacher_nets = []
        for teacher_model in args.teacher_models.split(","):
            teacher_net = prepare_model(
                model_name=teacher_model.strip(),
                use_pretrained=True,
                pretrained_model_file_path="",
                dtype=args.dtype,
                net_extra_kwargs=ds_metainfo.train_net_extra_kwargs,
                do_hybridize=(not args.not_hybridize),
                ctx=ctx)
            assert (teacher_net.classes == net.classes)
            assert (teacher_net.in_size == net.in_size)
            teacher_nets.append(teacher_net)
        if len(teacher_nets) > 0:
            teacher_net = Concurrent(stack=True, prefix="", branches=teacher_nets)
            # Teachers only provide targets, so their parameters receive no gradients.
            for k, v in teacher_net.collect_params().items():
                v.grad_req = "null"
            if not args.not_discriminator:
                discrim_net = MealDiscriminator()
                discrim_net.cast(args.dtype)
                if not args.not_hybridize:
                    discrim_net.hybridize(
                        static_alloc=True,
                        static_shape=True)
                discrim_net.initialize(mx.init.MSRAPrelu(), ctx=ctx)
                for k, v in discrim_net.collect_params().items():
                    v.lr_mult = args.dlr_factor
                discrim_loss_func = MealAdvLoss()

    train_data = get_train_data_source(
        ds_metainfo=ds_metainfo,
        batch_size=batch_size,
        num_workers=args.num_workers)
    val_data = get_val_data_source(
        ds_metainfo=ds_metainfo,
        batch_size=batch_size,
        num_workers=args.num_workers)
    batch_fn = get_batch_fn(ds_metainfo=ds_metainfo)

    num_training_samples = len(train_data._dataset) if not ds_metainfo.use_imgrec else ds_metainfo.num_training_samples
    trainer, lr_scheduler = prepare_trainer(
        net=net,
        optimizer_name=args.optimizer_name,
        wd=args.wd,
        momentum=args.momentum,
        lr_mode=args.lr_mode,
        lr=args.lr,
        lr_decay_period=args.lr_decay_period,
        lr_decay_epoch=args.lr_decay_epoch,
        lr_decay=args.lr_decay,
        target_lr=args.target_lr,
        poly_power=args.poly_power,
        warmup_epochs=args.warmup_epochs,
        warmup_lr=args.warmup_lr,
        warmup_mode=args.warmup_mode,
        batch_size=batch_size,
        num_epochs=args.num_epochs,
        num_training_samples=num_training_samples,
        dtype=args.dtype,
        gamma_wd_mult=args.gamma_wd_mult,
        beta_wd_mult=args.beta_wd_mult,
        bias_wd_mult=args.bias_wd_mult,
        state_file_path=args.resume_state)

    if args.save_dir and args.save_interval:
        param_names = ds_metainfo.val_metric_capts + ds_metainfo.train_metric_capts + ["Train.Loss", "LR"]
        lp_saver = TrainLogParamSaver(
            checkpoint_file_name_prefix="{}_{}".format(ds_metainfo.short_label, args.model),
            last_checkpoint_file_name_suffix="last",
            best_checkpoint_file_name_suffix=None,
            last_checkpoint_dir_path=args.save_dir,
            best_checkpoint_dir_path=None,
            last_checkpoint_file_count=2,
            best_checkpoint_file_count=2,
            checkpoint_file_save_callback=save_params,
            checkpoint_file_exts=(".params", ".states"),
            save_interval=args.save_interval,
            num_epochs=args.num_epochs,
            param_names=param_names,
            acc_ind=ds_metainfo.saver_acc_ind,
            # bigger=[True],
            # mask=None,
            score_log_file_path=os.path.join(args.save_dir, "score.log"),
            score_log_attempt_value=args.attempt,
            best_map_log_file_path=os.path.join(args.save_dir, "best_map.log"))
    else:
        lp_saver = None

    val_metric = get_composite_metric(ds_metainfo.val_metric_names, ds_metainfo.val_metric_extra_kwargs)
    train_metric = get_composite_metric(ds_metainfo.train_metric_names, ds_metainfo.train_metric_extra_kwargs)
    loss_metrics = [LossValue(name="loss"), LossValue(name="dloss")]

    # Hard (sparse) labels are only valid when neither mixup/label smoothing nor a
    # teacher ensemble produces soft targets.
    loss_kwargs = {"sparse_label": (not (args.mixup or args.label_smoothing) and
                                    not (use_teacher and (teacher_net is not None)))}
    if ds_metainfo.loss_extra_kwargs is not None:
        loss_kwargs.update(ds_metainfo.loss_extra_kwargs)
    loss_func = get_loss(ds_metainfo.loss_name, loss_kwargs)

    train_net(
        batch_size=batch_size,
        num_epochs=args.num_epochs,
        start_epoch1=args.start_epoch,
        train_data=train_data,
        val_data=val_data,
        batch_fn=batch_fn,
        data_source_needs_reset=ds_metainfo.use_imgrec,
        dtype=args.dtype,
        net=net,
        teacher_net=teacher_net,
        discrim_net=discrim_net,
        trainer=trainer,
        lr_scheduler=lr_scheduler,
        lp_saver=lp_saver,
        log_interval=args.log_interval,
        mixup=args.mixup,
        mixup_epoch_tail=args.mixup_epoch_tail,
        label_smoothing=args.label_smoothing,
        num_classes=num_classes,
        grad_clip_value=args.grad_clip,
        batch_size_scale=args.batch_size_scale,
        val_metric=val_metric,
        train_metric=train_metric,
        loss_metrics=loss_metrics,
        loss_func=loss_func,
        discrim_loss_func=discrim_loss_func,
        ctx=ctx)
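
# Editor's sketch (not part of the original script): the "sparse_label" flag passed to
# the loss above decides whether it consumes hard class indices or dense/soft target
# distributions. Mixup, label smoothing, and teacher-based distillation all yield soft
# targets, which is why sparse_label is switched off in those cases. The hypothetical
# helper below illustrates the difference with mxnet's stock SoftmaxCrossEntropyLoss;
# it is defined for reference only and is never called by the training pipeline.
def _sparse_vs_soft_label_demo():
    import mxnet as mx

    logits = mx.nd.array([[2.0, 0.5, 0.1]])

    # Hard targets: the label is a class index.
    hard_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=True)
    print(hard_loss(logits, mx.nd.array([0])))

    # Soft targets (mixup / label smoothing / teacher outputs): the label is a full
    # probability distribution over the classes.
    soft_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)
    print(soft_loss(logits, mx.nd.array([[0.9, 0.05, 0.05]])))
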
def main():
    """
    Main body of script.
    """
    args = parse_args()
    args.seed = init_rand(seed=args.seed)

    _, log_file_exist = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    ctx, batch_size = prepare_mx_context(
        num_gpus=args.num_gpus,
        batch_size=args.batch_size)

    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    ds_metainfo.update(args=args)

    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip(),
        dtype=args.dtype,
        net_extra_kwargs=ds_metainfo.train_net_extra_kwargs,
        load_ignore_extra=True,  # Ignore loading leftover skip connections - LIV
        tune_layers=args.tune_layers,
        classes=args.num_classes,
        in_channels=args.in_channels,
        do_hybridize=(not args.not_hybridize),
        initializer=get_initializer(initializer_name=args.initializer),
        ctx=ctx)
    assert (hasattr(net, "classes"))
    num_classes = net.classes

    train_data = get_train_data_source(
        ds_metainfo=ds_metainfo,
        batch_size=batch_size,
        num_workers=args.num_workers)
    val_data = get_val_data_source(
        ds_metainfo=ds_metainfo,
        batch_size=batch_size,
        num_workers=args.num_workers)
    batch_fn = get_batch_fn(ds_metainfo=ds_metainfo)

    num_training_samples = len(train_data._dataset) if not ds_metainfo.use_imgrec else ds_metainfo.num_training_samples
    trainer, lr_scheduler = prepare_trainer(
        net=net,
        optimizer_name=args.optimizer_name,
        wd=args.wd,
        momentum=args.momentum,
        lr_mode=args.lr_mode,
        lr=args.lr,
        lr_decay_period=args.lr_decay_period,
        lr_decay_epoch=args.lr_decay_epoch,
        lr_decay=args.lr_decay,
        target_lr=args.target_lr,
        poly_power=args.poly_power,
        warmup_epochs=args.warmup_epochs,
        warmup_lr=args.warmup_lr,
        warmup_mode=args.warmup_mode,
        batch_size=batch_size,
        num_epochs=args.num_epochs,
        num_training_samples=num_training_samples,
        dtype=args.dtype,
        gamma_wd_mult=args.gamma_wd_mult,
        beta_wd_mult=args.beta_wd_mult,
        bias_wd_mult=args.bias_wd_mult,
        state_file_path=args.resume_state)

    if args.save_dir and args.save_interval:
        param_names = ds_metainfo.val_metric_capts + ds_metainfo.train_metric_capts + ["Train.Loss", "LR"]
        lp_saver = TrainLogParamSaver(
            checkpoint_file_name_prefix="{}_{}".format(ds_metainfo.short_label, args.model),
            last_checkpoint_file_name_suffix="last",
            best_checkpoint_file_name_suffix=None,
            last_checkpoint_dir_path=args.save_dir,
            best_checkpoint_dir_path=None,
            last_checkpoint_file_count=2,
            best_checkpoint_file_count=2,
            checkpoint_file_save_callback=save_params,
            checkpoint_file_exts=(".params", ".states"),
            save_interval=args.save_interval,
            num_epochs=args.num_epochs,
            param_names=param_names,
            acc_ind=ds_metainfo.saver_acc_ind,
            # bigger=[True],
            # mask=None,
            score_log_file_path=os.path.join(args.save_dir, "score.log"),
            score_log_attempt_value=args.attempt,
            best_map_log_file_path=os.path.join(args.save_dir, "best_map.log"))
    else:
        lp_saver = None

    val_metric = get_composite_metric(ds_metainfo.val_metric_names, ds_metainfo.val_metric_extra_kwargs)
    train_metric = get_composite_metric(ds_metainfo.train_metric_names, ds_metainfo.train_metric_extra_kwargs)

    loss_kwargs = {"sparse_label": not (args.mixup or args.label_smoothing)}
    if ds_metainfo.loss_extra_kwargs is not None:
        loss_kwargs.update(ds_metainfo.loss_extra_kwargs)
    loss_func = get_loss(ds_metainfo.loss_name, loss_kwargs)

    print(net)

    train_net(
        batch_size=batch_size,
        num_epochs=args.num_epochs,
        start_epoch1=args.start_epoch,
        train_data=train_data,
        val_data=val_data,
        batch_fn=batch_fn,
        data_source_needs_reset=ds_metainfo.use_imgrec,
        dtype=args.dtype,
        net=net,
        trainer=trainer,
        lr_scheduler=lr_scheduler,
        lp_saver=lp_saver,
        log_interval=args.log_interval,
        mixup=args.mixup,
        mixup_epoch_tail=args.mixup_epoch_tail,
        label_smoothing=args.label_smoothing,
        num_classes=num_classes,
        grad_clip_value=args.grad_clip,
        batch_size_scale=args.batch_size_scale,
        val_metric=val_metric,
        train_metric=train_metric,
        loss_func=loss_func,
        ctx=ctx)
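
# Editor's assumption (the guard is not shown in the excerpt): the script is presumably
# executed directly, so the conventional entry point would simply invoke main().
if __name__ == "__main__":
    main()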