def __init__(self, state_space, act_n, quantile_dim, num_quantiles,
             hidden_dim, num_hidden, optim_params):
    """
    Rainbow Recurrent IQN agent setup.

    IQN: https://arxiv.org/pdf/1806.06923.pdf
    R2D2: https://openreview.net/pdf?id=r1lyTjAqYX
    R2D3: https://arxiv.org/abs/1909.01387

    Args:
        state_space: observation space forwarded to ``Model``.
        act_n: number of discrete actions.
        quantile_dim: embedding size for the quantile fractions.
        num_quantiles: quantile samples per forward pass.
        hidden_dim: hidden layer width forwarded to ``Model``.
        num_hidden: number of hidden layers forwarded to ``Model``.
        optim_params: keyword arguments forwarded verbatim to RAdam.
    """
    nn.Module.__init__(self)

    # Online network is the one being optimized; the target is a deep
    # copy whose weights are only refreshed explicitly (see update_target
    # elsewhere in this file), giving stable bootstrap targets.
    self.online = Model(state_space, act_n, quantile_dim, num_quantiles,
                        hidden_dim, num_hidden)
    self.target = deepcopy(self.online)

    # Huber-style loss used for the quantile regression term.
    self.loss_func = nn.SmoothL1Loss(reduction="mean")
    # Only the online network's parameters are trained.
    self.optim = RAdam(self.online.parameters(), **optim_params)
elif args.arch == 'stacked': model = stacked_transformer_model else: raise TypeError if args.optimizer.lower() == 'adam': optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2)) elif args.optimizer.lower() == 'sgd': optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) elif args.optimizer.lower() == 'radam': optimizer = RAdam(model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2)) else: raise TypeError iterator = BucketIterator(batch_size=args.batch, sorting_keys=[("source", "num_tokens")]) iterator.index_with(vocab) #scheduler = _PyTorchLearningRateSchedulerWrapper(ReduceLROnPlateau(optimizer, patience=4)) if torch.cuda.is_available(): cuda_device = 0 model = model.cuda(cuda_device) print('using gpu') else: cuda_device = -1
def __init__(self, encoder, decoder, optimizer_params=None, amp_params=None,
             n_jobs=0, rank=0):
    """Set up a distributed (apex AMP + DDP) trainer for the encoder.

    Args:
        encoder: backbone network producing class logits; moved to the
            GPU selected by ``rank`` and wrapped in apex DDP.
        decoder: only consulted for its ``num_classes`` attribute — the
            decoder / reconstruction path is currently disabled (see the
            commented-out lines below).
        optimizer_params: optional dict with keys ``lr`` (default 1e-3)
            and ``weight_decay`` (default 0).  Keys ``warmap`` and
            ``amsgrad`` were read historically but never used, so they
            are now ignored entirely.
        amp_params: optional dict with keys ``opt_level`` (default 'O0')
            and ``loss_scale`` (default None), forwarded to apex AMP.
        n_jobs: DataLoader worker count used later by ``train``.
        rank: local process rank / CUDA device index.
    """
    # BUG FIX: mutable default arguments ({}) replaced with None —
    # shared-mutable-default pitfall; caller-visible behavior unchanged.
    if optimizer_params is None:
        optimizer_params = {}
    if amp_params is None:
        amp_params = {}

    lr = optimizer_params.get('lr', 1e-3)
    weight_decay = optimizer_params.get('weight_decay', 0)
    opt_level = amp_params.get('opt_level', 'O0')
    loss_scale = amp_params.get('loss_scale', None)

    self.device = torch.device('cuda:' + str(rank))
    self.encoder = encoder.to(self.device)
    #self.decoder = decoder.to(self.device)
    self.num_classes = decoder.num_classes

    # NOTE(review): attribute keeps the historical name ("mse"/"critetion"
    # typo) for compatibility, but it actually holds an L1 loss.
    self.mse_critetion = nn.L1Loss()
    self.ce_criterion = LabelSmoothingLoss(self.num_classes,
                                           smoothing=0.1,
                                           reduction='none').to(self.device)
    self.vat_criterion = VATLoss()
    self.cutmix = CutMix(self.num_classes)

    # Exempt batch-norm parameters and biases from weight decay.
    param_optimizer = list(self.encoder.named_parameters())  #+ list(self.decoder.named_parameters())
    no_decay = ['bn', 'bias']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': weight_decay
        },
        {
            'params': [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        },
    ]
    self.optimizer = RAdam(optimizer_grouped_parameters,
                           lr=lr,
                           weight_decay=weight_decay)

    self.is_master = torch.distributed.get_rank() == 0
    torch.cuda.set_device(rank)
    # apex expects lists of models; unpack the single encoder back out.
    [self.encoder], self.optimizer = apex.amp.initialize(
        [self.encoder],
        self.optimizer,
        opt_level=opt_level,
        loss_scale=loss_scale,
        verbosity=1)
    self.scheduler = StepLR(self.optimizer, step_size=20, gamma=0.5)
    self.encoder = apex.parallel.DistributedDataParallel(
        self.encoder, delay_allreduce=True)
    #self.decoder = apex.parallel.DistributedDataParallel(self.decoder, delay_allreduce=True)

    self.last_epoch = 0
    self.n_jobs = n_jobs
class Trainer:
    """Distributed apex-AMP trainer for an encoder classifier with CutMix.

    The decoder / reconstruction path is disabled throughout (commented
    out); only the encoder is trained.  Every rank computes; rank 0 drives
    the tqdm progress bars and receives the metrics reduced via
    ``torch.distributed.reduce``.
    """

    def __init__(self, encoder, decoder, optimizer_params={}, amp_params={},
                 n_jobs=0, rank=0):
        """Build losses, grouped RAdam optimizer, apex AMP and DDP wrapping.

        NOTE(review): mutable default arguments ({}) are an anti-pattern;
        harmless here only because they are read, never mutated.
        """
        lr = optimizer_params.get('lr', 1e-3)
        weight_decay = optimizer_params.get('weight_decay', 0)
        # NOTE(review): 'warmap' (sic) and 'amsgrad' are read but unused.
        warmap = optimizer_params.get('warmap', 100)
        amsgrad = optimizer_params.get('amsgrad', False)
        opt_level = amp_params.get('opt_level', 'O0')
        loss_scale = amp_params.get('loss_scale', None)

        self.device = torch.device('cuda:' + str(rank))
        self.encoder = encoder.to(self.device)
        #self.decoder = decoder.to(self.device)
        self.num_classes = decoder.num_classes

        # NOTE(review): despite the "mse" name, this is an L1 loss.
        self.mse_critetion = nn.L1Loss()
        self.ce_criterion = LabelSmoothingLoss(self.num_classes,
                                               smoothing=0.1,
                                               reduction='none').to(
                                                   self.device)
        self.vat_criterion = VATLoss()
        self.cutmix = CutMix(self.num_classes)

        # Exempt batch-norm parameters and biases from weight decay.
        param_optimizer = list(self.encoder.named_parameters()
                               )  #+ list(self.decoder.named_parameters())
        no_decay = ['bn', 'bias']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            weight_decay
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        self.optimizer = RAdam(optimizer_grouped_parameters,
                               lr=lr,
                               weight_decay=weight_decay)

        self.is_master = torch.distributed.get_rank() == 0
        torch.cuda.set_device(rank)
        # apex expects lists of models; unpack the single encoder back out.
        [self.encoder
         ], self.optimizer = apex.amp.initialize([self.encoder],
                                                 self.optimizer,
                                                 opt_level=opt_level,
                                                 loss_scale=loss_scale,
                                                 verbosity=1)
        self.scheduler = StepLR(self.optimizer, step_size=20, gamma=0.5)
        self.encoder = apex.parallel.DistributedDataParallel(
            self.encoder, delay_allreduce=True)
        #self.decoder = apex.parallel.DistributedDataParallel(self.decoder, delay_allreduce=True)

        self.last_epoch = 0
        self.n_jobs = n_jobs

    def _train_epoch(self, train_dataloader):
        """Run one CutMix training pass over ``train_dataloader``.

        Logs cross-rank-averaged running losses on rank 0 via tqdm.
        """
        if self.is_master:
            pbar = tqdm(desc=f'Train, epoch #{self.last_epoch}',
                        total=len(train_dataloader))

        self.encoder.train()
        #self.decoder.train()

        sum_loss, cls_loss = AvgMeter(), AvgMeter()
        for images, labels in train_dataloader:
            # CutMix yields the mixed batch, the permuted labels and the
            # per-batch mixing coefficient l.
            images, labels, shuffled_labels, l = self.cutmix(images, labels)
            images = images.to(self.device)
            labels = labels.to(self.device)
            shuffled_labels = shuffled_labels.to(self.device)
            l = l.to(self.device)

            self.optimizer.zero_grad()

            #loss_vat = self.vat_criterion(self.encoder, images)
            label_preds = self.encoder(images)
            #reconsts_l = self.decoder(latents, labels)
            #with disable_tracking_bn_stats(self.encoder):
            #    latents_l, label_preds_l = self.encoder(reconsts_l)
            #labels_r = torch.randint_like(labels, low=0, high=self.num_classes)
            #reconsts_r = self.decoder(latents, labels_r)
            #with disable_tracking_bn_stats(self.encoder):
            #    latents_r, label_preds_r = self.encoder(reconsts_r)

            # CutMix objective: convex combination of the CE losses against
            # the original and the shuffled labels, weighted by l.
            loss_c = (l * self.ce_criterion(label_preds, labels) +
                      (1 - l) * self.ce_criterion(label_preds,
                                                  shuffled_labels)).mean()
            #loss_r = self.mse_critetion(reconsts_l, images)
            #loss_e = self.ce_criterion(label_preds_r, labels_r)
            #loss_i = self.mse_critetion(latents_l, latents_r)
            losses = loss_c  #+ loss_vat # + loss_r + loss_e + loss_i

            # apex AMP loss scaling around backward().
            with apex.amp.scale_loss(losses, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            self.optimizer.step()

            sum_loss.update(losses.item())
            cls_loss.update(loss_c.item())

            # Reduce the running meters to rank 0 for display.
            info_tensor = torch.tensor([sum_loss(), cls_loss()],
                                       device=self.device)
            torch.distributed.reduce(info_tensor, dst=0)
            if self.is_master:
                info_tensor = info_tensor / torch.distributed.get_world_size()
                pbar.update(1)
                pbar.set_postfix({
                    'sum_loss': info_tensor[0].item(),
                    'cls_loss': info_tensor[1].item()
                })

        # NOTE(review): indentation reconstructed from mangled source —
        # assumed to step once per epoch (consistent with
        # StepLR(step_size=20)); confirm against the original layout.
        self.scheduler.step()

    def _test_epoch(self, test_dataloader):
        """Evaluate on ``test_dataloader``; return the reduced accuracy.

        Returns 0 (the initial value) on non-master ranks that never enter
        the rank-0 branch, and the cross-rank-averaged accuracy on rank 0.
        """
        with torch.no_grad():
            if self.is_master:
                pbar = tqdm(desc=f'Test, epoch #{self.last_epoch}',
                            total=len(test_dataloader))

            self.encoder.eval()

            loss, acc, quality_metric = AvgMeter(), AvgMeter(), 0
            for images, labels in test_dataloader:
                images = images.to(self.device)
                labels = labels.to(self.device)

                label_preds = self.encoder(images)
                loss_val = self.ce_criterion(label_preds, labels).mean()
                acc_val = (torch.argmax(label_preds,
                                        dim=-1) == labels).float().mean()

                loss.update(loss_val.item())
                acc.update(acc_val.item())

                info_tensor = torch.tensor([loss(), acc()],
                                           device=self.device)
                torch.distributed.reduce(info_tensor, dst=0)
                if self.is_master:
                    info_tensor = info_tensor / torch.distributed.get_world_size(
                    )
                    quality_metric = info_tensor[1].item()
                    pbar.update(1)
                    pbar.set_postfix({
                        'loss': info_tensor[0].item(),
                        'acc': info_tensor[1].item()
                    })

            return quality_metric

    def _save_checkpoint(self, checkpoint_path):
        """Save the bare encoder weights (unwrapped from DDP via .module)."""
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save(self.encoder.module.state_dict(), checkpoint_path)

    def train(self, train_data, n_epochs, batch_size, test_data=None,
              last_checkpoint_path=None, best_checkpoint_path=None):
        """Main loop: train ``n_epochs`` epochs with DistributedSampler.

        ``batch_size`` is the global batch size and is divided by the world
        size to get the per-rank batch.  ``last_checkpoint_path`` is written
        every epoch; ``best_checkpoint_path`` only when the test metric
        improves.  Only rank 0 writes checkpoints.
        """
        num_replicas = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        batch_size = batch_size // num_replicas

        train_sampler = DistributedSampler(train_data,
                                           shuffle=True,
                                           num_replicas=num_replicas,
                                           rank=rank)
        train_dataloader = DataLoader(train_data,
                                      batch_size=batch_size,
                                      sampler=train_sampler,
                                      num_workers=self.n_jobs)
        if test_data is not None:
            test_sampler = DistributedSampler(test_data,
                                              shuffle=False,
                                              num_replicas=num_replicas,
                                              rank=rank)
            test_dataloader = DataLoader(test_data,
                                         batch_size=batch_size,
                                         sampler=test_sampler,
                                         num_workers=self.n_jobs)

        best_metric = float('-inf')
        for epoch in range(n_epochs):
            torch.cuda.empty_cache()
            self._train_epoch(train_dataloader)

            if last_checkpoint_path is not None and self.is_master:
                self._save_checkpoint(last_checkpoint_path)

            if test_data is not None:
                torch.cuda.empty_cache()
                metric = self._test_epoch(test_dataloader)

                if best_checkpoint_path is not None and self.is_master:
                    if metric > best_metric:
                        best_metric = metric
                        self._save_checkpoint(best_checkpoint_path)

            self.last_epoch += 1
def create_optimizer(cfg, model, filter_bias_and_bn=True):
    """Build the optimizer described by ``cfg.SOLVER``.

    Supports plain torch optimizers plus the project's extras (AdamW,
    Nadam, RAdam, RMSpropTF, NovoGrad, NvNovoGrad), with an optional
    ``lookahead_`` prefix that wraps the result in Lookahead.

    Args:
        cfg: config object exposing ``SOLVER.OPTIMIZER``, ``BASE_LR``,
            ``WEIGHT_DECAY``, ``MOMENTUM``, ``BIAS_LR_FACTOR`` and
            ``WEIGHT_DECAY_BIAS``.
        model: module whose parameters are optimized.
        filter_bias_and_bn: when True and weight decay is non-zero, put
            each parameter in its own group so biases get a scaled lr and
            their own decay value.

    Returns:
        A configured optimizer instance.

    Raises:
        ValueError: if the optimizer name is not recognized.
    """
    opt_key = cfg.SOLVER.OPTIMIZER.lower()
    base_lr = cfg.SOLVER.BASE_LR
    decay = cfg.SOLVER.WEIGHT_DECAY
    momentum = cfg.SOLVER.MOMENTUM

    if 'adamw' in opt_key or 'radam' in opt_key:
        # Compensate for the way current AdamW and RAdam optimizers apply LR
        # to the weight-decay term — they don't follow the paper/original
        # Torch7 impl which scales decay by current_lr/initial_lr.
        decay /= base_lr

    if decay and filter_bias_and_bn:
        # One group per parameter: biases get BIAS_LR_FACTOR-scaled lr and
        # WEIGHT_DECAY_BIAS; everything else gets the (adjusted) decay.
        param_groups = []
        for name, tensor in model.named_parameters():
            if not tensor.requires_grad:
                continue
            group_lr = base_lr
            group_decay = decay
            if "bias" in name:
                group_lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
                group_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
            param_groups.append({
                "params": [tensor],
                "lr": group_lr,
                "weight_decay": group_decay,
            })
        parameters = param_groups
        # Per-group values now carry the decay; zero the global one.
        decay = 0.
    else:
        parameters = model.parameters()

    tokens = opt_key.split('_')
    opt_name = tokens[-1]

    # Lazy factories: only the chosen entry is evaluated, so optional
    # optimizer classes may be absent without breaking the others.
    factories = {
        'sgd': lambda: optim.SGD(parameters,
                                 lr=base_lr,
                                 momentum=momentum,
                                 weight_decay=decay,
                                 nesterov=True),
        'adam': lambda: optim.Adam(parameters, lr=base_lr,
                                   weight_decay=decay),
        'adamw': lambda: AdamW(parameters, lr=base_lr, weight_decay=decay),
        'nadam': lambda: Nadam(parameters, lr=base_lr, weight_decay=decay),
        'radam': lambda: RAdam(parameters, lr=base_lr, weight_decay=decay),
        'adadelta': lambda: optim.Adadelta(parameters,
                                           lr=base_lr,
                                           weight_decay=decay),
        'rmsprop': lambda: optim.RMSprop(parameters,
                                         lr=base_lr,
                                         alpha=0.9,
                                         momentum=momentum,
                                         weight_decay=decay),
        'rmsproptf': lambda: RMSpropTF(parameters,
                                       lr=base_lr,
                                       alpha=0.9,
                                       momentum=momentum,
                                       weight_decay=decay),
        'novograd': lambda: NovoGrad(parameters, lr=base_lr,
                                     weight_decay=decay),
        'nvnovograd': lambda: NvNovoGrad(parameters,
                                         lr=base_lr,
                                         weight_decay=decay),
    }
    if opt_name not in factories:
        raise ValueError("Invalid optimizer")
    optimizer = factories[opt_name]()

    if len(tokens) > 1 and tokens[0] == 'lookahead':
        optimizer = Lookahead(optimizer)
    return optimizer
def main():
    """Entry point: train a Bengali grapheme classifier from a config file.

    Pipeline: parse config -> set up workspace/logging -> build augmentors
    and transformers -> build datasets/loaders -> build model, losses and
    optional feature-metric criteria -> optimizer + (optional) apex AMP ->
    scheduler -> hand everything to ``train``.
    """
    args = parse_args()
    conf = Config(args.conf)
    data_dir = conf.data_dir
    fold_id = conf.fold_id
    workspace = Workspace(conf.run_id).setup()
    workspace.save_conf(args.conf)
    workspace.log(f'{conf.to_dict()}')
    torch.cuda.set_device(0)

    # ---- data augmentation -------------------------------------------------
    if conf.use_augmentor:
        if conf.augmentor_type == 'v1':
            augmentor = create_augmentor_v1(
                enable_random_morph=conf.enable_random_morph)
        elif conf.augmentor_type == 'v2':
            augmentor = create_augmentor_v2(
                enable_random_morph=conf.enable_random_morph,
                invert_color=conf.invert_color)
        elif conf.augmentor_type == 'v3':
            # v3 needs the target input size up front.
            if conf.input_size_tuple:
                input_size = tuple(conf.input_size_tuple)
            else:
                input_size = (conf.input_size, conf.input_size) if conf.input_size else \
                    (SOURCE_IMAGE_HEIGHT, SOURCE_IMAGE_WIDTH)
            augmentor = create_augmentor_v3(
                input_size,
                enable_random_morph=conf.enable_random_morph,
                invert_color=conf.invert_color)
        else:
            raise ValueError(conf.augmentor_type)
        workspace.log(f'Use augmentor: {conf.augmentor_type}')
    else:
        augmentor = None

    # ---- input transformers ------------------------------------------------
    if not conf.input_size_tuple and conf.input_size == 0:
        train_transformer = create_transformer_v1(augmentor=augmentor)
        val_transformer = create_testing_transformer_v1()
        workspace.log('Input size: default')
    else:
        if conf.input_size_tuple:
            input_size = tuple(conf.input_size_tuple)
        else:
            input_size = (conf.input_size, conf.input_size)
        train_transformer = create_transformer_v1(input_size=input_size,
                                                  augmentor=augmentor)
        val_transformer = create_testing_transformer_v1(
            input_size=input_size)
        workspace.log(f'Input size: {input_size}')

    # ---- datasets ----------------------------------------------------------
    train_dataset, val_dataset = bengali_dataset(
        data_dir,
        fold_id=fold_id,
        train_transformer=train_transformer,
        val_transformer=val_transformer,
        invert_color=conf.invert_color,
        n_channel=conf.n_channel,
        use_grapheme_code=conf.use_grapheme_code,
        logger=workspace.logger)
    workspace.log(f'#train={len(train_dataset)}, #val={len(val_dataset)}')
    train_dataset.set_low_freq_groups(n_class=conf.n_class_low_freq)

    # ---- samplers / loaders ------------------------------------------------
    if conf.sampler_type == 'pk':
        sampler = PKSampler(train_dataset,
                            n_iter_per_epoch=conf.n_iter_per_epoch,
                            p=conf.batch_p,
                            k=conf.batch_k)
        train_loader = DataLoader(train_dataset,
                                  shuffle=False,
                                  num_workers=8,
                                  pin_memory=True,
                                  batch_sampler=sampler)
        workspace.log(f'{sampler} is enabled')
        workspace.log(f'Real batch_size={sampler.batch_size}')
    elif conf.sampler_type == 'random+append':
        batch_sampler = LowFreqSampleMixinBatchSampler(
            train_dataset,
            conf.batch_size,
            n_low_freq_samples=conf.n_low_freq_samples,
            drop_last=True)
        train_loader = DataLoader(train_dataset,
                                  shuffle=False,
                                  num_workers=8,
                                  pin_memory=True,
                                  batch_sampler=batch_sampler)
        workspace.log(f'{batch_sampler} is enabled')
        workspace.log(f'Real batch_size={batch_sampler.batch_size}')
    elif conf.sampler_type == 'random':
        train_loader = DataLoader(train_dataset,
                                  batch_size=conf.batch_size,
                                  shuffle=True,
                                  num_workers=8,
                                  pin_memory=True,
                                  drop_last=True)
    else:
        raise ValueError(f'Invalid sampler_type: {conf.sampler_type}')

    val_loader = DataLoader(val_dataset,
                            batch_size=conf.batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True)

    # ---- model -------------------------------------------------------------
    workspace.log(f'Create init model: arch={conf.arch}')
    model = create_init_model(conf.arch,
                              pretrained=True,
                              pooling=conf.pooling_type,
                              dim=conf.feat_dim,
                              use_maxblurpool=conf.use_maxblurpool,
                              remove_last_stride=conf.remove_last_stride,
                              n_channel=conf.n_channel)
    if conf.weight_file:
        pretrained_weight = torch.load(conf.weight_file, map_location='cpu')
        result = model.load_state_dict(pretrained_weight)
        workspace.log(f'Pretrained weights were loaded: {conf.weight_file}')
        workspace.log(result)
    model = model.cuda()
    # Extra trainable modules (metric-loss heads / grapheme classifier).
    sub_models = []

    # ---- classification criteria (grapheme / vowel / consonant) -----------
    # NOTE(review): criterion_g passes ``weight=`` while criterion_v/_c pass
    # ``weights=`` — confirm against get_criterion's signature; one of the
    # two spellings may be silently ignored.
    criterion_g = get_criterion(conf.loss_type_g,
                                weight=train_dataset.get_class_weights_g(),
                                rate=conf.ohem_rate)
    workspace.log(f'Loss type (g): {conf.loss_type_g}')
    criterion_v = get_criterion(conf.loss_type_v,
                                weights=train_dataset.get_class_weights_v(),
                                rate=conf.ohem_rate)
    workspace.log(f'Loss type (v): {conf.loss_type_v}')
    criterion_c = get_criterion(conf.loss_type_c,
                                weights=train_dataset.get_class_weights_c(),
                                rate=conf.ohem_rate)
    workspace.log(f'Loss type (c): {conf.loss_type_c}')

    # ---- optional feature-space (metric) criteria --------------------------
    # n_class values: 168 grapheme roots, 11 vowel diacritics, 7 consonant
    # diacritics.  'af' (ArcFace-style) criteria carry parameters, so they
    # are registered as sub-models for the optimizer.
    if conf.loss_type_feat_g != 'none':
        assert isinstance(
            model,
            (M.BengaliResNet34V3, M.BengaliResNet34V4,
             M.BengaliResNet34AGeMV4, M.BengaliSEResNeXt50V4,
             M.BengaliEfficientNetB0V4, M.BengaliEfficientNetB3V4))
        criterion_feat_g = get_criterion(conf.loss_type_feat_g,
                                         dim=model.multihead.head_g.dim,
                                         n_class=168,
                                         s=conf.af_scale_g)
        workspace.log(f'Loss type (fg): {conf.loss_type_feat_g}')
        if conf.loss_type_feat_g in ('af', ):
            sub_models.append(criterion_feat_g)
            workspace.log('Add criterion_feat_g to sub model')
    else:
        criterion_feat_g = None
    if conf.loss_type_feat_v != 'none':
        assert isinstance(
            model,
            (M.BengaliResNet34V3, M.BengaliResNet34V4,
             M.BengaliResNet34AGeMV4, M.BengaliSEResNeXt50V4,
             M.BengaliEfficientNetB0V4, M.BengaliEfficientNetB3V4))
        criterion_feat_v = get_criterion(conf.loss_type_feat_v,
                                         dim=model.multihead.head_v.dim,
                                         n_class=11,
                                         s=conf.af_scale_v)
        workspace.log(f'Loss type (fv): {conf.loss_type_feat_v}')
        if conf.loss_type_feat_v in ('af', ):
            sub_models.append(criterion_feat_v)
            workspace.log('Add criterion_feat_v to sub model')
    else:
        criterion_feat_v = None
    if conf.loss_type_feat_c != 'none':
        assert isinstance(
            model,
            (M.BengaliResNet34V3, M.BengaliResNet34V4,
             M.BengaliResNet34AGeMV4, M.BengaliSEResNeXt50V4,
             M.BengaliEfficientNetB0V4, M.BengaliEfficientNetB3V4))
        criterion_feat_c = get_criterion(conf.loss_type_feat_c,
                                         dim=model.multihead.head_c.dim,
                                         n_class=7,
                                         s=conf.af_scale_c)
        workspace.log(f'Loss type (fc): {conf.loss_type_feat_c}')
        if conf.loss_type_feat_c in ('af', ):
            sub_models.append(criterion_feat_c)
            workspace.log('Add criterion_feat_c to sub model')
    else:
        criterion_feat_c = None

    # ---- optional grapheme-code head: concat of the three head widths -----
    if conf.use_grapheme_code:
        workspace.log('Use grapheme code classifier')
        grapheme_classifier = nn.Sequential(nn.BatchNorm1d(168 + 11 + 7),
                                            nn.Linear(168 + 11 + 7, 1295))
        grapheme_classifier = grapheme_classifier.cuda()
        grapheme_classifier.train()
        sub_models.append(grapheme_classifier)
        criterion_grapheme = L.OHEMCrossEntropyLoss().cuda()
    else:
        grapheme_classifier = None
        criterion_grapheme = None

    # ---- optimizer ---------------------------------------------------------
    parameters = [{'params': model.parameters()}] + \
        [{'params': sub_model.parameters()} for sub_model in sub_models]
    if conf.optimizer_type == 'adam':
        optimizer = torch.optim.Adam(parameters, lr=conf.lr)
    elif conf.optimizer_type == 'sgd':
        optimizer = torch.optim.SGD(parameters,
                                    lr=conf.lr,
                                    momentum=0.9,
                                    weight_decay=1e-4)
    elif conf.optimizer_type == 'ranger':
        optimizer = Ranger(parameters, lr=conf.lr, weight_decay=1e-4)
    elif conf.optimizer_type == 'radam':
        optimizer = RAdam(parameters, lr=conf.lr, weight_decay=1e-4)
    else:
        raise ValueError(conf.optimizer_type)
    workspace.log(f'Optimizer type: {conf.optimizer_type}')

    # ---- apex AMP ----------------------------------------------------------
    if conf.use_apex:
        workspace.log('Apex initialization')
        _models, optimizer = amp.initialize([model] + sub_models,
                                            optimizer,
                                            opt_level=conf.apex_opt_level)
        if len(_models) == 1:
            model = _models[0]
        else:
            # NOTE(review): this unpacking assumes sub_models is exactly
            # [criterion_feat_g, criterion_feat_v, criterion_feat_c]; with a
            # different subset (or with the grapheme classifier appended)
            # the indices are wrong or raise IndexError — confirm.
            model = _models[0]
            criterion_feat_g = _models[1]
            criterion_feat_v = _models[2]
            criterion_feat_c = _models[3]
        workspace.log('Initialized by Apex')
        workspace.log(f'{optimizer.__class__.__name__}')
        for m in _models:
            workspace.log(f'{m.__class__.__name__}')

    # ---- LR scheduler ------------------------------------------------------
    if conf.scheduler_type == 'cosanl':
        scheduler = CosineLRWithRestarts(
            optimizer,
            conf.batch_size,
            len(train_dataset),
            restart_period=conf.cosanl_restart_period,
            t_mult=conf.cosanl_t_mult)
        workspace.log(f'restart_period={scheduler.restart_period}')
        workspace.log(f't_mult={scheduler.t_mult}')
    elif conf.scheduler_type == 'rop':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=conf.rop_patience,
            mode='max',
            factor=conf.rop_factor,
            min_lr=1e-6,
            verbose=True)
    else:
        raise ValueError(conf.scheduler_type)

    # ---- hand off to the training loop ------------------------------------
    train(model,
          train_loader,
          val_loader,
          optimizer,
          criterion_g,
          criterion_v,
          criterion_c,
          criterion_feat_g,
          criterion_feat_v,
          criterion_feat_c,
          workspace,
          scheduler=scheduler,
          n_epoch=conf.n_epoch,
          cutmix_prob=conf.cutmix_prob,
          mixup_prob=conf.mixup_prob,
          freeze_bn_epochs=conf.freeze_bn_epochs,
          feat_loss_weight=conf.feat_loss_weight,
          use_apex=conf.use_apex,
          decrease_ohem_rate=conf.decrease_ohem_rate,
          use_grapheme_code=conf.use_grapheme_code,
          grapheme_classifier=grapheme_classifier,
          criterion_grapheme=criterion_grapheme,
          final_ft=conf.final_ft)
class IQN(nn.Module):
    """Rainbow Recurrent IQN agent with online/target networks (R2D3-style).

    NOTE(review): ``train`` shadows ``nn.Module.train(mode)``; the name is
    kept for interface compatibility, but calling code cannot toggle
    train/eval mode through this module the usual way.
    """

    def __init__(self, state_space, act_n, quantile_dim, num_quantiles,
                 hidden_dim, num_hidden, optim_params):
        """
        Rainbow Recurrent IQN

        IQN: https://arxiv.org/pdf/1806.06923.pdf
        R2D2: https://openreview.net/pdf?id=r1lyTjAqYX
        R2D3: https://arxiv.org/abs/1909.01387

        Args:
            state_space: observation space forwarded to ``Model``.
            act_n: number of discrete actions.
            quantile_dim: embedding size for the quantile fractions.
            num_quantiles: quantile samples per forward pass.
            hidden_dim, num_hidden: hidden-layer configuration.
            optim_params: keyword arguments forwarded to RAdam.
        """
        nn.Module.__init__(self)

        # BUG FIX: train_batch reads self.num_quantiles but it was never
        # stored, which raised AttributeError on the first training step.
        self.num_quantiles = num_quantiles

        self.online = Model(state_space, act_n, quantile_dim, num_quantiles,
                            hidden_dim, num_hidden)
        self.target = deepcopy(self.online)

        self.loss_func = nn.SmoothL1Loss(reduction="mean")
        self.optim = RAdam(self.online.parameters(), **optim_params)

    def forward(self, inp):
        """Forward through the online network."""
        return self.online(inp)

    def step(self, state, greedy=False):
        """
        Takes a step into the environment
        """
        return self.online.step(state, greedy)

    def train_batch(self, rollouts, burn_in_length, sequence_length):
        """
        Trains for a batch of rollouts with the given burn in length and
        training sequence length
        """
        self.optim.zero_grad()

        states, actions, rewards, next_states, terminals, hidden_state = rollouts

        # Add burn in here #######

        next_q_vals, next_quantile_vals, next_quantiles, next_hidden = \
            self.target(next_states)

        # BUG FIX: ``next_quantile_vals[1]`` indexed batch element 1; the
        # quantile count is the size of dimension 1.
        num_quantiles = next_quantile_vals.size(1)

        # BUG FIX: greedy next actions come from the mean Q-values
        # (next_q_vals was previously computed but unused), and keepdim
        # takes a bool, not 1.
        next_actions = next_q_vals.argmax(-1, keepdim=True)
        next_actions = next_actions.unsqueeze(1).repeat(1, num_quantiles, 1)
        # BUG FIX: squeeze the gathered action dim (-1); squeeze(1) was a
        # no-op because dim 1 is the quantile dimension.
        next_values = next_quantile_vals.gather(-1, next_actions).squeeze(-1)

        # TODO(review): target() is unpacked into 4 values but online() into
        # 3 — confirm Model.forward's arity; one of the two is wrong.
        q_vals, quantile_vals, quantiles = self.online(states)
        # assumes ``actions`` is already an index tensor shaped for gather
        # along the action dim — TODO confirm against the replay buffer.
        action_values = quantile_vals.gather(-1, actions)

        # Pairwise TD errors between target and online quantile samples.
        td_error = next_values.unsqueeze(2) - action_values.unsqueeze(1)
        quantile_loss = self.loss_func(next_values.unsqueeze(2),
                                       action_values.unsqueeze(1))

        quantiles = quantiles.unsqueeze(1).repeat(1, self.num_quantiles, 1)
        # Quantile-regression asymmetric weight |tau - 1{td_error < 0}|.
        penalty = torch.abs(quantiles - (td_error < 0).float().detach())
        loss = penalty * quantile_loss
        # Divide by huber kappa
        loss = loss.sum(2).mean(1)
        # BUG FIX: loss is 1-D (per-sample) at this point, so mean(1)
        # raised; average over the batch instead.
        meaned_loss = loss.mean()

        meaned_loss.backward()
        self.optim.step()

        return meaned_loss, loss

    def train(self, num_batches, batch_size, burn_in_length, sequence_length,
              online_replay_buffer=None, supervised_replay_buffer=None,
              supervised_chance=0.25, writer=None):
        """
        Trains R2D3 style with 2 replay buffers
        """
        # At least one of the two buffers must be provided (clearer than
        # the original chained-comparison assert, same intent).
        assert not (online_replay_buffer is None
                    and supervised_replay_buffer is None)

        for batch in range(1, num_batches + 1):
            # BUG FIX: np.rand does not exist; the uniform draw is
            # np.random.rand().
            buff_choice = np.random.rand()

            if (online_replay_buffer is None
                    or buff_choice < supervised_chance):
                replay_buffer = supervised_replay_buffer
            else:
                replay_buffer = online_replay_buffer

            # NOTE(review): busy-wait until the buffer can serve a batch;
            # consider a blocking sample or a short sleep to avoid pegging
            # a core.
            while (not replay_buffer.ready_to_sample(batch_size)):
                pass

            rollouts, idxs, is_weights = replay_buffer.sample(batch_size)

            loss, new_errors = self.train_batch(rollouts, burn_in_length,
                                                sequence_length)
            replay_buffer.update_priorities(new_errors, idxs)

            if (writer is not None):
                if (buff_choice < supervised_chance):
                    writer.add_summary("Supervised Loss", loss, batch)
                else:
                    writer.add_summary("Online Loss", loss, batch)
                writer.add_summary("Loss", loss, batch)

    def update_target(self):
        """
        Updates the target network
        """
        self.target.load_state_dict(self.online.state_dict())
def _run_stage(runner, model, criterion, optimizer, loaders, num_epochs,
               logdir):
    """Run one catalyst training stage, logging tensorboard + accuracy."""
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        num_epochs=num_epochs,
        verbose=True,
        logdir=logdir,
        callbacks=[
            logger.TensorboardLogger(),
            AccuracyCallback(num_classes=10)
        ],
    )


def main():
    """Benchmark Adam vs AdamW vs RAdam on CIFAR-10 with resnet34.

    The model is re-initialized (``init_weights``) between stages so the
    three optimizers start from comparable weights; each stage logs to its
    own ./logdir/<name> directory.
    """
    # Shared transform: upscale CIFAR images to the resnet input size.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    cifar_train = CIFAR10('.', train=True, transform=transform,
                          download=True)
    cifar_test = CIFAR10('.', train=False, transform=transform,
                         download=True)

    dl_train = DataLoader(cifar_train, batch_size=16)
    dl_test = DataLoader(cifar_test, batch_size=16)
    loaders = {'train': dl_train, 'valid': dl_test}
    num_epochs = 10

    model = resnet34()
    for param in model.parameters():
        param.requires_grad = True
    model.train()

    criterion = torch.nn.CrossEntropyLoss()
    runner = dl.SupervisedRunner()

    _run_stage(runner, model, criterion,
               torch.optim.Adam(model.parameters()),
               loaders, num_epochs, "./logdir/Adam")

    # BUG FIX: AdamW() and RAdam() were constructed with no arguments —
    # torch-style optimizers require the parameter iterable, so these
    # calls raised at runtime.  Pass model.parameters() as for Adam.
    model.apply(init_weights)
    _run_stage(runner, model, criterion, AdamW(model.parameters()),
               loaders, num_epochs, "./logdir/AdamW")

    model.apply(init_weights)
    _run_stage(runner, model, criterion, RAdam(model.parameters()),
               loaders, num_epochs, "./logdir/RAdam")
def train(args, persuasive_data_iter, tree_data_iter, model, criterion,
          device, multitask=False):
    """Train the persuasiveness model, optionally multitasking on trees.

    Args:
        args: namespace carrying optimizer/scheduler/loop hyperparameters
            (optimizer, lr, weight_decay, lr_step, lr_gamma, grad_clip,
            save_path, accumulate, epoch, *_alpha, tree_count).
        persuasive_data_iter: (train, dev, test) iterables of paired docs.
        tree_data_iter: (train, dev, test) iterables for the tree task.
        model: network; called as ``model(**data)`` and, for the tree task,
            ``model(**data, multitask=True)``.
        criterion: loss passed through to the *_cal_score helpers.
        device: torch device for ``convert``.
        multitask: when True (and tree data exists), interleave one tree
            batch every ``args.tree_count`` persuasive batches.
    """
    model.train()

    # ---- optimizer ---------------------------------------------------------
    if (args.optimizer == "Adam"):
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    elif (args.optimizer == 'AdamW'):
        optimizer = optim.AdamW(model.parameters(),
                                lr=args.lr,
                                weight_decay=args.weight_decay)
    elif (args.optimizer == 'Ranger'):
        optimizer = Ranger(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    elif (args.optimizer == 'Radam'):
        optimizer = RAdam(model.parameters(),
                          lr=args.lr,
                          weight_decay=args.weight_decay)
    elif (args.optimizer == 'SGD'):
        # SGD gets a 1000x larger lr than the adaptive optimizers.
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr * 1000,
                              momentum=0.9,
                              weight_decay=args.weight_decay)
    else:
        raise NotImplementedError
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.lr_step,
                                          gamma=args.lr_gamma)

    # grad_clip is currently unused — the clipping call below is commented.
    grad_clip = args.grad_clip
    save_path = args.save_path
    accumulate = args.accumulate
    print_every = 100 * accumulate
    eval_every = 25 * accumulate
    #print_every, eval_every = 2, 2

    # "epoch" here counts optimizer-visible steps, not dataset passes.
    total_epoch = args.epoch * len(persuasive_data_iter[0])
    print('total training step:', total_epoch)

    persuasive_datas = iter(persuasive_data_iter[0])
    if ((tree_data_iter[0] is not None) and multitask):
        tree_datas = iter(tree_data_iter[0])

    multi_alpha = (args.ac_type_alpha, args.link_type_alpha)
    direct_alpha = (args.adu_alpha, args.para_alpha)
    alpha = (direct_alpha, multi_alpha)
    tree_count = args.tree_count
    best_acc = [0, 0]

    # ---- training loop -----------------------------------------------------
    model.zero_grad()
    t = time.time()
    persuasive = [[[], [], []], [[], [], []]]
    tree_preds = {
        'label': collections.defaultdict(list),
        'pred': collections.defaultdict(list),
        'loss': collections.defaultdict(float),
        'count': 0
    }
    for count in range(1, total_epoch + 1):
        try:
            datas = next(persuasive_datas)
        # BUG FIX: bare ``except:`` also swallowed KeyboardInterrupt and
        # real errors from the loader; only exhaustion should restart it.
        except StopIteration:
            persuasive_datas = iter(persuasive_data_iter[0])
            datas = next(persuasive_datas)

        outputs = []
        for data in datas:
            data = convert(data, device)
            pred = model(**data)
            outputs.append(pred)
        labels = {
            'adu_direct': [datas[0]['author'], datas[1]['author']],
            'para_direct':
            [datas[0]['para_author'], datas[1]['para_author']],
        }
        # simply compare two value
        loss, outputs, labels = persuasive_cal_score(outputs, labels,
                                                     criterion, direct_alpha)
        for i, (p, l) in enumerate(zip(outputs, labels)):
            persuasive[0][i].append(p)
            persuasive[1][i].append(l)
        loss.backward()

        # One tree batch every ``tree_count`` persuasive batches.
        if (multitask and (count % tree_count == 0)):
            try:
                data, label = next(tree_datas)
            # BUG FIX: bare except narrowed, as above.
            except StopIteration:
                tree_datas = iter(tree_data_iter[0])
                data, label = next(tree_datas)
            data = convert(data, device)
            output = model(**data, multitask=True)
            output = {
                'type': output[0],
                'link': output[1],
                'link_type': output[2]
            }
            label = convert(label, device)
            loss, loss_stat, output = tree_cal_score(output, label, None,
                                                     multi_alpha)
            loss.backward()
            for key, val in label.items():
                tree_preds['label'][key].append(val.detach().cpu())
            for key, val in loss_stat.items():
                tree_preds['loss'][key] += val
            update_pred(tree_preds['pred'], output, data['adu_length'])
            tree_preds['count'] += 1

        # Gradient accumulation: step only every ``accumulate`` batches.
        if (count % accumulate == 0):
            #utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            optimizer.zero_grad()

        # Periodic training-set statistics.
        if (count % eval_every == 0):
            stat = update(persuasive, criterion, dtype='persuasive')
            nt = time.time()
            print(
                'now:{}, time: {:.4f}s'.format(count, nt - t),
                '\npersuasive: [loss: {:.4f}, diff: {:.4f}, acc: {:.4f}]'.
                format(stat[0][0], stat[0][1], stat[0][2]),
                '\tdirect: [adu_loss: {:.4f}, adu_acc: {:.4f}, para_loss: {:.4f}, para_acc: {:.4f}]'
                .format(stat[1][0], stat[1][1], stat[2][0], stat[2][1]),
                flush=True)
            if (multitask):
                stat = update(tree_preds, dtype='tree')
                print(
                    'acc: [link_mst: {:.4f}, link: {:.4f}, type: {:.4f}, link_type: {:.4f}]'
                    .format(stat['acc']['link_mst'], stat['acc']['link'],
                            stat['acc']['type'], stat['acc']['link_type']))
                print(
                    'f1: type: [premise: {:.4f}, claim: {:.4f}], link_type: [support{:.4f}, attack: {:.4f}]'
                    .format(stat['type']['premise'], stat['type']['claim'],
                            stat['link_type']['support'],
                            stat['link_type']['attack']))
                print('mrr: {:.4f}'.format(stat['mrr_link']), flush=True)
            t = nt
            persuasive = [[[], [], []], [[], [], []]]
            tree_preds = {
                'label': collections.defaultdict(list),
                'pred': collections.defaultdict(list),
                'loss': collections.defaultdict(float),
                'count': 0
            }

        # NOTE(review): indentation reconstructed from mangled source — the
        # scheduler is assumed to step once per loop iteration (consistent
        # with step_size=args.lr_step counting steps); confirm.
        scheduler.step()

        # Periodic dev/test evaluation and best-checkpoint saving.
        if (count % print_every == 0):
            dev_acc = test('dev {}'.format(count), persuasive_data_iter[1],
                           tree_data_iter[1], model, criterion, device,
                           alpha)
            test_acc = test('test {}'.format(count), persuasive_data_iter[2],
                            tree_data_iter[2], model, criterion, device,
                            alpha)
            if (dev_acc > best_acc[0]):
                best_acc = [dev_acc, test_acc]
                torch.save(model.state_dict(),
                           save_path + '/check_{}.pt'.format(count))
    print('all finish with acc:', best_acc)