def train_epoch(train_loader, model_list, optimizer_list, epoch, log):
    global global_step

    meters = AverageMeterSet()

    # define criterions
    class_criterion = nn.CrossEntropyLoss(size_average=False, ignore_index=NO_LABEL).cuda()
    residual_logit_criterion = losses.symmetric_mse_loss
    if args.consistency_type == 'mse':
        consistency_criterion = losses.softmax_mse_loss
        stabilization_criterion = losses.softmax_mse_loss
    elif args.consistency_type == 'kl':
        consistency_criterion = losses.softmax_kl_loss
        stabilization_criterion = losses.softmax_kl_loss

    for model in model_list:
        model.train()

    end = time.time()
    for i, (input_list, target) in enumerate(train_loader):
        meters.update('data_time', time.time() - end)

        # adjust learning rate for every optimizer
        for odx, optimizer in enumerate(optimizer_list):
            adjust_learning_rate(optimizer, epoch, i, len(train_loader))
            meters.update('lr_{0}'.format(odx), optimizer.param_groups[0]['lr'])

        # prepare data: one gradient-tracking and one no-grad view per student
        input_var_list, nograd_input_var_list = [], []
        for idx, inp in enumerate(input_list):
            input_var_list.append(Variable(inp))
            nograd_input_var_list.append(Variable(inp, requires_grad=False, volatile=True))
        target_var = Variable(target.cuda(async=True))

        minibatch_size = len(target_var)
        labeled_minibatch_size = target_var.data.ne(NO_LABEL).sum()
        unlabeled_minibatch_size = minibatch_size - labeled_minibatch_size
        assert labeled_minibatch_size >= 0 and unlabeled_minibatch_size >= 0
        meters.update('labeled_minibatch_size', labeled_minibatch_size)
        meters.update('unlabeled_minibatch_size', unlabeled_minibatch_size)

        loss_list = []
        cls_v_list, nograd_cls_v_list = [], []
        cls_i_list, nograd_cls_i_list = [], []
        mask_list, nograd_mask_list = [], []
        class_logit_list, nograd_class_logit_list = [], []
        cons_logit_list = []
        in_cons_logit_list, tar_class_logit_list = [], []

        # for each student model
        for mdx, model in enumerate(model_list):
            # forward
            class_logit, cons_logit = model(input_var_list[mdx])
            nograd_class_logit, nograd_cons_logit = model(nograd_input_var_list[mdx])

            # calculate res_loss, class_loss and consistency_loss inside each student model
            res_loss = args.logit_distance_cost * residual_logit_criterion(
                class_logit, cons_logit) / minibatch_size
            meters.update('{0}_res_loss'.format(mdx), res_loss.data[0])

            class_loss = class_criterion(class_logit, target_var) / minibatch_size
            meters.update('{0}_class_loss'.format(mdx), class_loss.data[0])

            consistency_weight = args.consistency_scale * ramps.sigmoid_rampup(
                epoch, args.consistency_rampup)
            nograd_class_logit = Variable(nograd_class_logit.detach().data, requires_grad=False)
            consistency_loss = consistency_weight * consistency_criterion(
                cons_logit, nograd_class_logit) / minibatch_size
            meters.update('{0}_cons_loss'.format(mdx), consistency_loss.data[0])

            loss = class_loss + res_loss + consistency_loss
            loss_list.append(loss)

            # store variables for calculating the stabilization loss:
            # value (cls_v) and index (cls_i) of the max probability in the prediction
            cls_v, cls_i = torch.max(F.softmax(class_logit, dim=1), dim=1)
            nograd_cls_v, nograd_cls_i = torch.max(F.softmax(nograd_class_logit, dim=1), dim=1)
            cls_v_list.append(cls_v)
            cls_i_list.append(cls_i.data.cpu().numpy())
            nograd_cls_v_list.append(nograd_cls_v)
            nograd_cls_i_list.append(nograd_cls_i.data.cpu().numpy())

            # stable prediction mask
            mask = (cls_v > args.stable_threshold)
            nograd_mask = (nograd_cls_v > args.stable_threshold)
            mask_list.append(mask.data.cpu().numpy())
            nograd_mask_list.append(nograd_mask.data.cpu().numpy())

            class_logit_list.append(class_logit)
            cons_logit_list.append(cons_logit)
            nograd_class_logit_list.append(nograd_class_logit)

            # detached logits for generating the stabilization targets
            in_cons_logit = Variable(cons_logit.detach().data, requires_grad=False)
            in_cons_logit_list.append(in_cons_logit)
            tar_class_logit = Variable(class_logit.clone().detach().data, requires_grad=False)
            tar_class_logit_list.append(tar_class_logit)

        # calculate stabilization weight
        stabilization_weight = args.stabilization_scale * ramps.sigmoid_rampup(
            epoch, args.stabilization_rampup)
        if not args.exclude_unlabeled:
            stabilization_weight = (unlabeled_minibatch_size / minibatch_size) * stabilization_weight

        # randomly pair the students; each pair (l, r) forms a dual student
        model_idx = np.arange(0, len(model_list))
        np.random.shuffle(model_idx)
        for idx in range(0, len(model_idx)):
            if idx % 2 != 0:
                continue
            l_mdx, r_mdx = model_idx[idx], model_idx[idx + 1]

            # generate stabilization target for each sample
            for sdx in range(0, minibatch_size):
                l_stable = False
                if mask_list[l_mdx][sdx] == 0 and nograd_mask_list[l_mdx][sdx] == 0:
                    # unstable: does not satisfy the 2nd condition
                    tar_class_logit_list[l_mdx][sdx, ...] = in_cons_logit_list[r_mdx][sdx, ...]
                elif cls_i_list[l_mdx][sdx] != nograd_cls_i_list[l_mdx][sdx]:
                    # unstable: does not satisfy the 1st condition
                    tar_class_logit_list[l_mdx][sdx, ...] = in_cons_logit_list[r_mdx][sdx, ...]
                else:
                    l_stable = True

                r_stable = False
                if mask_list[r_mdx][sdx] == 0 and nograd_mask_list[r_mdx][sdx] == 0:
                    # unstable: does not satisfy the 2nd condition
                    tar_class_logit_list[r_mdx][sdx, ...] = in_cons_logit_list[l_mdx][sdx, ...]
                elif cls_i_list[r_mdx][sdx] != nograd_cls_i_list[r_mdx][sdx]:
                    # unstable: does not satisfy the 1st condition
                    tar_class_logit_list[r_mdx][sdx, ...] = in_cons_logit_list[l_mdx][sdx, ...]
                else:
                    r_stable = True

                # compare stabilities if both l and r models are stable for a sample
                if l_stable and r_stable:
                    l_sample_cons = consistency_criterion(
                        cons_logit_list[l_mdx][sdx:sdx + 1, ...],
                        nograd_class_logit_list[r_mdx][sdx:sdx + 1, ...])
                    r_sample_cons = consistency_criterion(
                        cons_logit_list[r_mdx][sdx:sdx + 1, ...],
                        nograd_class_logit_list[l_mdx][sdx:sdx + 1, ...])
                    if l_sample_cons.data.cpu().numpy()[0] < r_sample_cons.data.cpu().numpy()[0]:
                        # loss: l -> r
                        tar_class_logit_list[r_mdx][sdx, ...] = in_cons_logit_list[l_mdx][sdx, ...]
                    elif l_sample_cons.data.cpu().numpy()[0] > r_sample_cons.data.cpu().numpy()[0]:
                        # loss: r -> l
                        tar_class_logit_list[l_mdx][sdx, ...] = in_cons_logit_list[r_mdx][sdx, ...]

            # stabilization losses for the pair
            if args.exclude_unlabeled:
                l_stabilization_loss = stabilization_weight * stabilization_criterion(
                    cons_logit_list[l_mdx], tar_class_logit_list[r_mdx]) / minibatch_size
                r_stabilization_loss = stabilization_weight * stabilization_criterion(
                    cons_logit_list[r_mdx], tar_class_logit_list[l_mdx]) / minibatch_size
            else:
                for sdx in range(unlabeled_minibatch_size, minibatch_size):
                    tar_class_logit_list[l_mdx][sdx, ...] = in_cons_logit_list[r_mdx][sdx, ...]
                    tar_class_logit_list[r_mdx][sdx, ...] = in_cons_logit_list[l_mdx][sdx, ...]
                l_stabilization_loss = stabilization_weight * stabilization_criterion(
                    cons_logit_list[l_mdx], tar_class_logit_list[r_mdx]) / unlabeled_minibatch_size
                r_stabilization_loss = stabilization_weight * stabilization_criterion(
                    cons_logit_list[r_mdx], tar_class_logit_list[l_mdx]) / unlabeled_minibatch_size

            meters.update('{0}_stable_loss'.format(l_mdx), l_stabilization_loss.data[0])
            meters.update('{0}_stable_loss'.format(r_mdx), r_stabilization_loss.data[0])

            loss_list[l_mdx] = loss_list[l_mdx] + l_stabilization_loss
            loss_list[r_mdx] = loss_list[r_mdx] + r_stabilization_loss
            meters.update('{0}_loss'.format(l_mdx), loss_list[l_mdx].data[0])
            meters.update('{0}_loss'.format(r_mdx), loss_list[r_mdx].data[0])

        for mdx, model in enumerate(model_list):
            # calculate prec
            prec = mt_func.accuracy(class_logit_list[mdx].data, target_var.data, topk=(1, ))[0]
            meters.update('{0}_top1'.format(mdx), prec[0], labeled_minibatch_size)

            # backward and update
            optimizer_list[mdx].zero_grad()
            loss_list[mdx].backward()
            optimizer_list[mdx].step()

        # record
        global_step += 1
        meters.update('batch_time', time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            LOG.info('Epoch: [{0}][{1}/{2}]\t'
                     'Batch-T {meters[batch_time]:.3f}\t'.format(
                         epoch, i, len(train_loader), meters=meters))
            for mdx, model in enumerate(model_list):
                cur_class_loss = meters['{0}_class_loss'.format(mdx)].val
                avg_class_loss = meters['{0}_class_loss'.format(mdx)].avg
                cur_res_loss = meters['{0}_res_loss'.format(mdx)].val
                avg_res_loss = meters['{0}_res_loss'.format(mdx)].avg
                cur_cons_loss = meters['{0}_cons_loss'.format(mdx)].val
                avg_cons_loss = meters['{0}_cons_loss'.format(mdx)].avg
                cur_stable_loss = meters['{0}_stable_loss'.format(mdx)].val
                avg_stable_loss = meters['{0}_stable_loss'.format(mdx)].avg
                cur_top1_acc = meters['{0}_top1'.format(mdx)].val
                avg_top1_acc = meters['{0}_top1'.format(mdx)].avg
                LOG.info(
                    'model-{0}: Class {1:.4f}({2:.4f})\tRes {3:.4f}({4:.4f})\tCons {5:.4f}({6:.4f})\t'
                    'Stable {7:.4f}({8:.4f})\tPrec@1 {9:.3f}({10:.3f})\t'.format(
                        mdx, cur_class_loss, avg_class_loss, cur_res_loss, avg_res_loss,
                        cur_cons_loss, avg_cons_loss, cur_stable_loss, avg_stable_loss,
                        cur_top1_acc, avg_top1_acc))
            LOG.info('\n')

            log.record(epoch + i / len(train_loader), {
                'step': global_step,
                **meters.values(),
                **meters.averages(),
                **meters.sums()
            })
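

# Both train_epoch variants scale the consistency and stabilization losses with
# ramps.sigmoid_rampup(epoch, rampup_length). The ramps module is not shown in this
# section; the sketch below is the conventional mean-teacher-style ramp and is an
# assumption about what ramps.sigmoid_rampup computes, not a verbatim copy of it.
def _sigmoid_rampup_sketch(current, rampup_length):
    """Exponential ramp-up from ~0 to 1 over rampup_length epochs (assumed form)."""
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))

# Example (hypothetical values): with args.consistency_scale = 10.0 and
# args.consistency_rampup = 5, the consistency weight grows from about 0.07 at
# epoch 0 to 10.0 at epoch 5 and stays constant afterwards.
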
def train_epoch(train_loader, l_model, r_model, l_optimizer, r_optimizer, epoch, log):
    global global_step

    meters = AverageMeterSet()

    # define criterions
    class_criterion = nn.CrossEntropyLoss(size_average=False, ignore_index=NO_LABEL).cuda()
    residual_logit_criterion = losses.symmetric_mse_loss
    if args.consistency_type == 'mse':
        consistency_criterion = losses.softmax_mse_loss
        stabilization_criterion = losses.softmax_mse_loss
    elif args.consistency_type == 'kl':
        consistency_criterion = losses.softmax_kl_loss
        stabilization_criterion = losses.softmax_kl_loss

    l_model.train()
    r_model.train()

    end = time.time()
    for i, ((l_input, r_input), target) in enumerate(train_loader):
        meters.update('data_time', time.time() - end)

        # adjust learning rate
        adjust_learning_rate(l_optimizer, epoch, i, len(train_loader))
        adjust_learning_rate(r_optimizer, epoch, i, len(train_loader))
        meters.update('l_lr', l_optimizer.param_groups[0]['lr'])
        meters.update('r_lr', r_optimizer.param_groups[0]['lr'])

        # prepare data
        l_input_var = Variable(l_input)
        r_input_var = Variable(r_input)
        le_input_var = Variable(r_input, requires_grad=False, volatile=True)
        re_input_var = Variable(l_input, requires_grad=False, volatile=True)
        target_var = Variable(target.cuda(async=True))

        minibatch_size = len(target_var)
        labeled_minibatch_size = target_var.data.ne(NO_LABEL).sum()
        unlabeled_minibatch_size = minibatch_size - labeled_minibatch_size
        assert labeled_minibatch_size >= 0 and unlabeled_minibatch_size >= 0
        meters.update('labeled_minibatch_size', labeled_minibatch_size)
        meters.update('unlabeled_minibatch_size', unlabeled_minibatch_size)

        # forward
        l_model_out = l_model(l_input_var)
        r_model_out = r_model(r_input_var)
        le_model_out = l_model(le_input_var)
        re_model_out = r_model(re_input_var)

        if isinstance(l_model_out, Variable):
            assert args.logit_distance_cost < 0
            l_logit1 = l_model_out
            r_logit1 = r_model_out
            le_logit1 = le_model_out
            re_logit1 = re_model_out
        elif len(l_model_out) == 2:
            assert len(r_model_out) == 2
            l_logit1, l_logit2 = l_model_out
            r_logit1, r_logit2 = r_model_out
            le_logit1, le_logit2 = le_model_out
            re_logit1, re_logit2 = re_model_out

        # logit distance loss from mean teacher
        if args.logit_distance_cost >= 0:
            l_class_logit, l_cons_logit = l_logit1, l_logit2
            r_class_logit, r_cons_logit = r_logit1, r_logit2
            le_class_logit, le_cons_logit = le_logit1, le_logit2
            re_class_logit, re_cons_logit = re_logit1, re_logit2

            l_res_loss = args.logit_distance_cost * residual_logit_criterion(
                l_class_logit, l_cons_logit) / minibatch_size
            r_res_loss = args.logit_distance_cost * residual_logit_criterion(
                r_class_logit, r_cons_logit) / minibatch_size
            meters.update('l_res_loss', l_res_loss.data[0])
            meters.update('r_res_loss', r_res_loss.data[0])
        else:
            l_class_logit, l_cons_logit = l_logit1, l_logit1
            r_class_logit, r_cons_logit = r_logit1, r_logit1
            le_class_logit, le_cons_logit = le_logit1, le_logit1
            re_class_logit, re_cons_logit = re_logit1, re_logit1

            l_res_loss = 0.0
            r_res_loss = 0.0
            meters.update('l_res_loss', 0.0)
            meters.update('r_res_loss', 0.0)

        # classification loss
        l_class_loss = class_criterion(l_class_logit, target_var) / minibatch_size
        r_class_loss = class_criterion(r_class_logit, target_var) / minibatch_size
        meters.update('l_class_loss', l_class_loss.data[0])
        meters.update('r_class_loss', r_class_loss.data[0])

        l_loss, r_loss = l_class_loss, r_class_loss
        l_loss += l_res_loss
        r_loss += r_res_loss

        # consistency loss
        consistency_weight = args.consistency_scale * ramps.sigmoid_rampup(
            epoch, args.consistency_rampup)

        le_class_logit = Variable(le_class_logit.detach().data, requires_grad=False)
        l_consistency_loss = consistency_weight * consistency_criterion(
            l_cons_logit, le_class_logit) / minibatch_size
        meters.update('l_cons_loss', l_consistency_loss.data[0])
        l_loss += l_consistency_loss

        re_class_logit = Variable(re_class_logit.detach().data, requires_grad=False)
        r_consistency_loss = consistency_weight * consistency_criterion(
            r_cons_logit, re_class_logit) / minibatch_size
        meters.update('r_cons_loss', r_consistency_loss.data[0])
        r_loss += r_consistency_loss

        # stabilization loss
        # value (cls_v) and index (cls_i) of the max probability in the prediction
        l_cls_v, l_cls_i = torch.max(F.softmax(l_class_logit, dim=1), dim=1)
        r_cls_v, r_cls_i = torch.max(F.softmax(r_class_logit, dim=1), dim=1)
        le_cls_v, le_cls_i = torch.max(F.softmax(le_class_logit, dim=1), dim=1)
        re_cls_v, re_cls_i = torch.max(F.softmax(re_class_logit, dim=1), dim=1)

        l_cls_i = l_cls_i.data.cpu().numpy()
        r_cls_i = r_cls_i.data.cpu().numpy()
        le_cls_i = le_cls_i.data.cpu().numpy()
        re_cls_i = re_cls_i.data.cpu().numpy()

        # stable prediction mask
        l_mask = (l_cls_v > args.stable_threshold).data.cpu().numpy()
        r_mask = (r_cls_v > args.stable_threshold).data.cpu().numpy()
        le_mask = (le_cls_v > args.stable_threshold).data.cpu().numpy()
        re_mask = (re_cls_v > args.stable_threshold).data.cpu().numpy()

        # detach logits for generating the stabilization targets
        in_r_cons_logit = Variable(r_cons_logit.detach().data, requires_grad=False)
        tar_l_class_logit = Variable(l_class_logit.clone().detach().data, requires_grad=False)

        in_l_cons_logit = Variable(l_cons_logit.detach().data, requires_grad=False)
        tar_r_class_logit = Variable(r_class_logit.clone().detach().data, requires_grad=False)

        # generate target for each sample
        for sdx in range(0, minibatch_size):
            l_stable = False
            if l_mask[sdx] == 0 and le_mask[sdx] == 0:
                # unstable: does not satisfy the 2nd condition
                tar_l_class_logit[sdx, ...] = in_r_cons_logit[sdx, ...]
            elif l_cls_i[sdx] != le_cls_i[sdx]:
                # unstable: does not satisfy the 1st condition
                tar_l_class_logit[sdx, ...] = in_r_cons_logit[sdx, ...]
            else:
                l_stable = True

            r_stable = False
            if r_mask[sdx] == 0 and re_mask[sdx] == 0:
                # unstable: does not satisfy the 2nd condition
                tar_r_class_logit[sdx, ...] = in_l_cons_logit[sdx, ...]
            elif r_cls_i[sdx] != re_cls_i[sdx]:
                # unstable: does not satisfy the 1st condition
                tar_r_class_logit[sdx, ...] = in_l_cons_logit[sdx, ...]
            else:
                r_stable = True

            # compare stabilities if both models are stable for a sample
            if l_stable and r_stable:
                # compare by consistency
                l_sample_cons = consistency_criterion(
                    l_cons_logit[sdx:sdx + 1, ...], le_class_logit[sdx:sdx + 1, ...])
                r_sample_cons = consistency_criterion(
                    r_cons_logit[sdx:sdx + 1, ...], re_class_logit[sdx:sdx + 1, ...])
                if l_sample_cons.data.cpu().numpy()[0] < r_sample_cons.data.cpu().numpy()[0]:
                    # loss: l -> r
                    tar_r_class_logit[sdx, ...] = in_l_cons_logit[sdx, ...]
                elif l_sample_cons.data.cpu().numpy()[0] > r_sample_cons.data.cpu().numpy()[0]:
                    # loss: r -> l
                    tar_l_class_logit[sdx, ...] = in_r_cons_logit[sdx, ...]

        # calculate stabilization weight
        stabilization_weight = args.stabilization_scale * ramps.sigmoid_rampup(
            epoch, args.stabilization_rampup)
        if not args.exclude_unlabeled:
            stabilization_weight = (unlabeled_minibatch_size / minibatch_size) * stabilization_weight

        # stabilization loss for r model
        if args.exclude_unlabeled:
            r_stabilization_loss = stabilization_weight * stabilization_criterion(
                r_cons_logit, tar_l_class_logit) / minibatch_size
        else:
            for idx in range(unlabeled_minibatch_size, minibatch_size):
                tar_l_class_logit[idx, ...] = in_r_cons_logit[idx, ...]
            r_stabilization_loss = stabilization_weight * stabilization_criterion(
                r_cons_logit, tar_l_class_logit) / unlabeled_minibatch_size
        meters.update('r_stable_loss', r_stabilization_loss.data[0])
        r_loss += r_stabilization_loss

        # stabilization loss for l model
        if args.exclude_unlabeled:
            l_stabilization_loss = stabilization_weight * stabilization_criterion(
                l_cons_logit, tar_r_class_logit) / minibatch_size
        else:
            for idx in range(unlabeled_minibatch_size, minibatch_size):
                tar_r_class_logit[idx, ...] = in_l_cons_logit[idx, ...]
            l_stabilization_loss = stabilization_weight * stabilization_criterion(
                l_cons_logit, tar_r_class_logit) / unlabeled_minibatch_size
        meters.update('l_stable_loss', l_stabilization_loss.data[0])
        l_loss += l_stabilization_loss

        if np.isnan(l_loss.data[0]) or np.isnan(r_loss.data[0]):
            LOG.info('Loss value equals to NAN!')
            continue
        assert not (l_loss.data[0] > 1e5), 'L-Loss explosion: {}'.format(l_loss.data[0])
        assert not (r_loss.data[0] > 1e5), 'R-Loss explosion: {}'.format(r_loss.data[0])
        meters.update('l_loss', l_loss.data[0])
        meters.update('r_loss', r_loss.data[0])

        # calculate prec and error
        l_prec = mt_func.accuracy(l_class_logit.data, target_var.data, topk=(1, ))[0]
        r_prec = mt_func.accuracy(r_class_logit.data, target_var.data, topk=(1, ))[0]
        meters.update('l_top1', l_prec[0], labeled_minibatch_size)
        meters.update('l_error1', 100. - l_prec[0], labeled_minibatch_size)
        meters.update('r_top1', r_prec[0], labeled_minibatch_size)
        meters.update('r_error1', 100. - r_prec[0], labeled_minibatch_size)

        # update models
        l_optimizer.zero_grad()
        l_loss.backward()
        l_optimizer.step()

        r_optimizer.zero_grad()
        r_loss.backward()
        r_optimizer.step()

        # record
        global_step += 1
        meters.update('batch_time', time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            LOG.info('Epoch: [{0}][{1}/{2}]\t'
                     'Batch-T {meters[batch_time]:.3f}\t'
                     'L-Class {meters[l_class_loss]:.4f}\t'
                     'R-Class {meters[r_class_loss]:.4f}\t'
                     'L-Res {meters[l_res_loss]:.4f}\t'
                     'R-Res {meters[r_res_loss]:.4f}\t'
                     'L-Cons {meters[l_cons_loss]:.4f}\t'
                     'R-Cons {meters[r_cons_loss]:.4f}\n'
                     'L-Stable {meters[l_stable_loss]:.4f}\t'
                     'R-Stable {meters[r_stable_loss]:.4f}\t'
                     'L-Prec@1 {meters[l_top1]:.3f}\t'
                     'R-Prec@1 {meters[r_top1]:.3f}\t'.format(
                         epoch, i, len(train_loader), meters=meters))

            log.record(epoch + i / len(train_loader), {
                'step': global_step,
                **meters.values(),
                **meters.averages(),
                **meters.sums()
            })
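

# The criterions referenced above (losses.softmax_mse_loss, losses.softmax_kl_loss,
# losses.symmetric_mse_loss) live elsewhere in the project and are not shown in this
# section. The sketches below illustrate how such criteria are commonly written in
# mean-teacher-style codebases; they are assumptions about the behavior relied on
# here (a sum over the batch, divided by minibatch_size at the call site), not the
# project's exact implementation.
def _softmax_mse_loss_sketch(input_logits, target_logits):
    """MSE between softmax outputs; returns the sum over examples (assumed)."""
    assert input_logits.size() == target_logits.size()
    input_softmax = F.softmax(input_logits, dim=1)
    target_softmax = F.softmax(target_logits, dim=1)
    num_classes = input_logits.size()[1]
    return F.mse_loss(input_softmax, target_softmax, size_average=False) / num_classes


def _softmax_kl_loss_sketch(input_logits, target_logits):
    """KL(target || input) on softmax outputs; returns the sum over examples (assumed)."""
    assert input_logits.size() == target_logits.size()
    input_log_softmax = F.log_softmax(input_logits, dim=1)
    target_softmax = F.softmax(target_logits, dim=1)
    return F.kl_div(input_log_softmax, target_softmax, size_average=False)


def _symmetric_mse_loss_sketch(input1, input2):
    """MSE that backpropagates to both inputs; used for the residual logit loss (assumed)."""
    assert input1.size() == input2.size()
    num_classes = input1.size()[1]
    return torch.sum((input1 - input2) ** 2) / num_classes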